Add dedicated information panel to the UI (#137)

* Allow streaming to the chatbot and the information panel without threading
* Highlight evidence in a simple manner
Duc Nguyen (john), 2024-01-25 19:07:53 +07:00, committed by GitHub
parent ebc61400d8
commit 513e86f490
6 changed files with 116 additions and 27 deletions

File 1 of 6: the kotaemon BaseComponent

@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Iterator
+from typing import Iterator, Optional

 from kotaemon.base.schema import Document
 from theflow import Function, Node, Param, lazy
@@ -31,6 +31,17 @@ class BaseComponent(Function):
         return self.__call__(self.inflow.flow())

+    def set_output_queue(self, queue):
+        self._queue = queue
+        for name in self._ff_nodes:
+            node = getattr(self, name)
+            if isinstance(node, BaseComponent):
+                node.set_output_queue(queue)
+
+    def report_output(self, output: Optional[dict]):
+        if self._queue is not None:
+            self._queue.put_nowait(output)
+
     @abstractmethod
     def run(
         self, *args, **kwargs
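The two methods above define the streaming contract for every component: set_output_queue hands one shared asyncio.Queue down through all nested nodes, and report_output pushes incremental results (or a None sentinel) onto it without blocking. A minimal self-contained sketch of that contract in action; StreamingComponent here is a hypothetical stand-in, not the real BaseComponent:

import asyncio
from typing import Optional

class StreamingComponent:
    """Stand-in for BaseComponent's queue contract (illustrative only)."""

    _queue: Optional[asyncio.Queue] = None

    def set_output_queue(self, queue):
        self._queue = queue

    def report_output(self, output: Optional[dict]):
        if self._queue is not None:
            self._queue.put_nowait(output)

    async def run(self, text: str):
        for token in text.split():
            self.report_output({"output": token + " "})
            await asyncio.sleep(0)   # yield control so the consumer can drain
        self.report_output(None)      # sentinel: end of stream

async def main():
    queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
    comp = StreamingComponent()
    comp.set_output_queue(queue)
    asyncio.create_task(comp.run("streamed one token at a time"))
    while (item := await queue.get()) is not None:
        print(item["output"], end="", flush=True)

asyncio.run(main())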

File 2 of 6: application settings

@@ -11,7 +11,7 @@ user_cache_dir.mkdir(parents=True, exist_ok=True)

 COHERE_API_KEY = config("COHERE_API_KEY", default="")
-# KH_MODE = "dev"
+KH_MODE = "dev"
 KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"

 KH_DOCSTORE = {
     "__type__": "kotaemon.storages.SimpleFileDocumentStore",

File 3 of 6: stylesheet

@@ -32,6 +32,15 @@ footer {
     height: calc(100vh - 140px) !important;
 }

+#chat-info-panel {
+    max-height: calc(100vh - 140px) !important;
+    overflow-y: scroll !important;
+}
+
 .setting-answer-mode-description {
     margin: 5px 5px 2px !important
 }
+
+mark {
+    background-color: #1496bb;
+}
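These two rules size the new panel and color the highlights. The evidence strings produced later in this commit are raw HTML (<details>, <summary>, <mark>) passed through gr.Markdown, so a bare mark rule is all the highlighting needs. A tiny demo of the idea, assuming (as the commit does) that gr.Markdown passes embedded HTML through; the file name and quote are made up:

import gradio as gr

# Hypothetical evidence block, mirroring the HTML emitted by the pipeline.
evidence_html = (
    "<details open><summary>report.pdf</summary>"
    "The model <mark>reduces latency by 40%</mark> on average."
    "</details>"
)

with gr.Blocks(css="mark { background-color: #1496bb; }") as demo:
    # The embedded HTML reaches the browser; the css rule colors the marks.
    gr.Markdown(evidence_html, elem_id="chat-info-panel")

demo.launch()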

File 4 of 6: the ChatPage layout and event wiring

@@ -23,6 +23,9 @@ class ChatPage(BasePage):
                 self.report_issue = ReportIssue(self._app)
             with gr.Column(scale=6):
                 self.chat_panel = ChatPanel(self._app)
+            with gr.Column(scale=3):
+                with gr.Accordion(label="Information panel", open=True):
+                    self.info_panel = gr.Markdown(elem_id="chat-info-panel")

     def on_register_events(self):
         self.chat_panel.submit_btn.click(
@@ -33,7 +36,11 @@ class ChatPage(BasePage):
                 self.data_source.files,
                 self._app.settings_state,
             ],
-            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+            outputs=[
+                self.chat_panel.text_input,
+                self.chat_panel.chatbot,
+                self.info_panel,
+            ],
         ).then(
             fn=update_data_source,
             inputs=[
@@ -52,7 +59,11 @@ class ChatPage(BasePage):
                 self.data_source.files,
                 self._app.settings_state,
             ],
-            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+            outputs=[
+                self.chat_panel.text_input,
+                self.chat_panel.chatbot,
+                self.info_panel,
+            ],
         ).then(
             fn=update_data_source,
             inputs=[
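Both click chains now list self.info_panel as a third output, so each tuple yielded by the handler updates the textbox, the chatbot, and the information panel in one step. A stripped-down sketch of the same wiring, with a hypothetical echo handler in place of the real chat_fn:

import gradio as gr

async def chat_fn(message, history):
    # Yielding a 3-tuple updates the three outputs in order:
    # clear the textbox, extend the chat history, refresh the info panel.
    history = history + [(message, "echo: " + message)]
    yield "", history, "*evidence for this turn would appear here*"

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    info_panel = gr.Markdown(elem_id="chat-info-panel")
    text_input = gr.Textbox()
    submit_btn = gr.Button("Send")
    submit_btn.click(
        fn=chat_fn,
        inputs=[text_input, chatbot],
        outputs=[text_input, chatbot, info_panel],
    )

demo.launch()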

File 5 of 6: the chat event handlers

@@ -1,3 +1,4 @@
+import asyncio
 import os
 import tempfile
 from copy import deepcopy
@@ -116,24 +117,33 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
     return pipeline


-def chat_fn(chat_input, chat_history, files, settings):
+async def chat_fn(chat_input, chat_history, files, settings):
+    """Chat function"""
+    queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
+
+    # construct the pipeline
     pipeline = create_pipeline(settings, files)
+    pipeline.set_output_queue(queue)
+    asyncio.create_task(pipeline(chat_input))

-    text = ""
-    refs = []
-    for response in pipeline(chat_input):
-        if response.metadata.get("citation", None):
-            citation = response.metadata["citation"]
-            for idx, fact_with_evidence in enumerate(citation.answer):
-                quotes = fact_with_evidence.substring_quote
-                if quotes:
-                    refs.append(
-                        (None, f"***Reference {idx+1}***: {' ... '.join(quotes)}")
-                    )
-        else:
-            text += response.text
-        yield "", chat_history + [(chat_input, text)] + refs
+    text, refs = "", ""
+    while True:
+        try:
+            response = queue.get_nowait()
+        except Exception:
+            await asyncio.sleep(0)
+            continue
+
+        if response is None:
+            break
+
+        if "output" in response:
+            text += response["output"]
+        if "evidence" in response:
+            refs += response["evidence"]
+
+        yield "", chat_history + [(chat_input, text)], refs


 def is_liked(convo_id, liked: gr.LikeData):
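chat_fn is now the consumer end of the queue: the pipeline runs as a background task while the loop drains messages as they arrive, growing the answer text and the evidence HTML separately. Two small observations on the committed loop: catching asyncio.QueueEmpty would be more precise than the bare Exception, and get_nowait() plus await asyncio.sleep(0) is a cooperative busy-poll. A sketch of an equivalent consumer that suspends instead of spinning (our variant, not the committed code):

import asyncio
from typing import AsyncIterator, Optional

async def drain(queue: "asyncio.Queue[Optional[dict]]") -> AsyncIterator[tuple]:
    """Consume queue messages until the None sentinel arrives."""
    text, refs = "", ""
    while True:
        response = await queue.get()  # suspends; no busy-polling needed
        if response is None:          # sentinel from report_output(None)
            break
        text += response.get("output", "")
        refs += response.get("evidence", "")
        yield text, refs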

File 6 of 6: the QA pipelines

@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import warnings
 from collections import defaultdict
@@ -275,7 +276,7 @@ class AnswerWithContextPipeline(BaseComponent):
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese

-    def run(
+    async def run(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0
     ) -> Document | Iterator[Document]:
         """Answer the question based on the evidence
@@ -318,12 +319,16 @@
             SystemMessage(content="You are a helpful assistant"),
             HumanMessage(content=prompt),
         ]
-        # output = self.llm(messages).text
-        yield from self.llm(messages)
+        output = ""
+        for text in self.llm(messages):
+            output += text.text
+            self.report_output({"output": text.text})
+            await asyncio.sleep(0)

         citation = self.citation_pipeline(context=evidence, question=question)
-        answer = Document(text="", metadata={"citation": citation})
-        yield answer
+        answer = Document(text=output, metadata={"citation": citation})
+
+        return answer


 class FullQAPipeline(BaseComponent):
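Worth noting: the loop streams each LLM chunk through report_output while also accumulating the full text, so the final Document carries both the complete answer and the citation metadata. The queue messages form a small implicit protocol, summarized below (our gloss, not a type defined in the codebase):

from typing import Optional

# Shape of the messages flowing through the output queue:
#   {"output": str}   -> one chunk of streamed answer text
#   {"evidence": str} -> one block of highlighted-evidence HTML
#   None              -> end-of-stream sentinel
QueueMessage = Optional[dict]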
@@ -337,13 +342,56 @@
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()

-    def run(self, question: str, **kwargs) -> Iterator[Document]:
+    async def run(self, question: str, **kwargs) -> Document:  # type: ignore
         docs = self.retrieval_pipeline(text=question)
         evidence_mode, evidence = self.evidence_pipeline(docs).content
-        answer = self.answering_pipeline(
+        answer = await self.answering_pipeline(
             question=question, evidence=evidence, evidence_mode=evidence_mode
         )
-        yield from answer  # should be a generator
+
+        # prepare citation
+        from collections import defaultdict
+
+        spans = defaultdict(list)
+        for fact_with_evidence in answer.metadata["citation"].answer:
+            for quote in fact_with_evidence.substring_quote:
+                for doc in docs:
+                    start_idx = doc.text.find(quote)
+                    if start_idx >= 0:
+                        spans[doc.doc_id].append(
+                            {
+                                "start": start_idx,
+                                "end": start_idx + len(quote),
+                            }
+                        )
+                        break
+
+        id2docs = {doc.doc_id: doc for doc in docs}
+        for id, ss in spans.items():
+            if not ss:
+                continue
+            ss = sorted(ss, key=lambda x: x["start"])
+            text = id2docs[id].text[: ss[0]["start"]]
+            for idx, span in enumerate(ss):
+                text += (
+                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
+                )
+                if idx < len(ss) - 1:
+                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
+            text += id2docs[id].text[ss[-1]["end"] :]
+            self.report_output(
+                {
+                    "evidence": (
+                        "<details>"
+                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
+                        f"{text}"
+                        "</details><br>"
+                    )
+                }
+            )
+        self.report_output(None)
+
+        return answer

     @classmethod
     def get_pipeline(cls, settings, **kwargs):
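The citation pass is simple by design, matching the "highlight evidence in a simple manner" note in the commit message: each quoted substring is located with str.find (first occurrence only), the hits are sorted, and <mark> tags are spliced around them before the result is wrapped in a collapsible <details> block. The same idea as a standalone function; the name highlight and the sample strings are ours, and overlapping quotes are not handled, exactly as in the committed code:

def highlight(text: str, quotes: list[str]) -> str:
    """Wrap the first occurrence of each quote in <mark> tags."""
    spans = []
    for quote in quotes:
        start = text.find(quote)
        if start >= 0:                     # skip quotes not found verbatim
            spans.append((start, start + len(quote)))
    spans.sort()

    out, cursor = "", 0
    for start, end in spans:
        out += text[cursor:start] + "<mark>" + text[start:end] + "</mark>"
        cursor = end
    return out + text[cursor:]


print(highlight("The cat sat on the mat.", ["cat", "mat"]))
# The <mark>cat</mark> sat on the <mark>mat</mark>.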