Add dedicated information panel to the UI (#137)

* Allow streaming to the chatbot and the information panel without threading
* Highlight evidence in a simple manner
This commit is contained in:
Duc Nguyen (john) 2024-01-25 19:07:53 +07:00 committed by GitHub
parent ebc61400d8
commit 513e86f490
6 changed files with 116 additions and 27 deletions

View File

@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Iterator
from typing import Iterator, Optional
from kotaemon.base.schema import Document
from theflow import Function, Node, Param, lazy
@ -31,6 +31,17 @@ class BaseComponent(Function):
return self.__call__(self.inflow.flow())
def set_output_queue(self, queue):
self._queue = queue
for name in self._ff_nodes:
node = getattr(self, name)
if isinstance(node, BaseComponent):
node.set_output_queue(queue)
def report_output(self, output: Optional[dict]):
if self._queue is not None:
self._queue.put_nowait(output)
@abstractmethod
def run(
self, *args, **kwargs

View File

@ -11,7 +11,7 @@ user_cache_dir.mkdir(parents=True, exist_ok=True)
COHERE_API_KEY = config("COHERE_API_KEY", default="")
# KH_MODE = "dev"
KH_MODE = "dev"
KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
KH_DOCSTORE = {
"__type__": "kotaemon.storages.SimpleFileDocumentStore",

View File

@ -32,6 +32,15 @@ footer {
height: calc(100vh - 140px) !important;
}
#chat-info-panel {
max-height: calc(100vh - 140px) !important;
overflow-y: scroll !important;
}
.setting-answer-mode-description {
margin: 5px 5px 2px !important
}
mark {
background-color: #1496bb;
}

View File

@ -23,6 +23,9 @@ class ChatPage(BasePage):
self.report_issue = ReportIssue(self._app)
with gr.Column(scale=6):
self.chat_panel = ChatPanel(self._app)
with gr.Column(scale=3):
with gr.Accordion(label="Information panel", open=True):
self.info_panel = gr.Markdown(elem_id="chat-info-panel")
def on_register_events(self):
self.chat_panel.submit_btn.click(
@ -33,7 +36,11 @@ class ChatPage(BasePage):
self.data_source.files,
self._app.settings_state,
],
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
outputs=[
self.chat_panel.text_input,
self.chat_panel.chatbot,
self.info_panel,
],
).then(
fn=update_data_source,
inputs=[
@ -52,7 +59,11 @@ class ChatPage(BasePage):
self.data_source.files,
self._app.settings_state,
],
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
outputs=[
self.chat_panel.text_input,
self.chat_panel.chatbot,
self.info_panel,
],
).then(
fn=update_data_source,
inputs=[

View File

@ -1,3 +1,4 @@
import asyncio
import os
import tempfile
from copy import deepcopy
@ -116,24 +117,33 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
return pipeline
def chat_fn(chat_input, chat_history, files, settings):
async def chat_fn(chat_input, chat_history, files, settings):
"""Chat function"""
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
# construct the pipeline
pipeline = create_pipeline(settings, files)
pipeline.set_output_queue(queue)
text = ""
refs = []
for response in pipeline(chat_input):
if response.metadata.get("citation", None):
citation = response.metadata["citation"]
for idx, fact_with_evidence in enumerate(citation.answer):
quotes = fact_with_evidence.substring_quote
if quotes:
refs.append(
(None, f"***Reference {idx+1}***: {' ... '.join(quotes)}")
)
else:
text += response.text
asyncio.create_task(pipeline(chat_input))
text, refs = "", ""
yield "", chat_history + [(chat_input, text)] + refs
while True:
try:
response = queue.get_nowait()
except Exception:
await asyncio.sleep(0)
continue
if response is None:
break
if "output" in response:
text += response["output"]
if "evidence" in response:
refs += response["evidence"]
yield "", chat_history + [(chat_input, text)], refs
def is_liked(convo_id, liked: gr.LikeData):

View File

@ -1,3 +1,4 @@
import asyncio
import logging
import warnings
from collections import defaultdict
@ -275,7 +276,7 @@ class AnswerWithContextPipeline(BaseComponent):
system_prompt: str = ""
lang: str = "English" # support English and Japanese
def run(
async def run( # type: ignore
self, question: str, evidence: str, evidence_mode: int = 0
) -> Document | Iterator[Document]:
"""Answer the question based on the evidence
@ -318,12 +319,16 @@ class AnswerWithContextPipeline(BaseComponent):
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
# output = self.llm(messages).text
yield from self.llm(messages)
output = ""
for text in self.llm(messages):
output += text.text
self.report_output({"output": text.text})
await asyncio.sleep(0)
citation = self.citation_pipeline(context=evidence, question=question)
answer = Document(text="", metadata={"citation": citation})
yield answer
answer = Document(text=output, metadata={"citation": citation})
return answer
class FullQAPipeline(BaseComponent):
@ -337,13 +342,56 @@ class FullQAPipeline(BaseComponent):
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
def run(self, question: str, **kwargs) -> Iterator[Document]:
async def run(self, question: str, **kwargs) -> Document: # type: ignore
docs = self.retrieval_pipeline(text=question)
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = self.answering_pipeline(
answer = await self.answering_pipeline(
question=question, evidence=evidence, evidence_mode=evidence_mode
)
yield from answer # should be a generator
# prepare citation
from collections import defaultdict
spans = defaultdict(list)
for fact_with_evidence in answer.metadata["citation"].answer:
for quote in fact_with_evidence.substring_quote:
for doc in docs:
start_idx = doc.text.find(quote)
if start_idx >= 0:
spans[doc.doc_id].append(
{
"start": start_idx,
"end": start_idx + len(quote),
}
)
break
id2docs = {doc.doc_id: doc for doc in docs}
for id, ss in spans.items():
if not ss:
continue
ss = sorted(ss, key=lambda x: x["start"])
text = id2docs[id].text[: ss[0]["start"]]
for idx, span in enumerate(ss):
text += (
"<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
)
if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :]
self.report_output(
{
"evidence": (
"<details>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{text}"
"</details><br>"
)
}
)
self.report_output(None)
return answer
@classmethod
def get_pipeline(cls, settings, **kwargs):