Add dedicated information panel to the UI (#137)
* Allow streaming to the chatbot and the information panel without threading * Highlight evidence in a simple manner
This commit is contained in:
parent
ebc61400d8
commit
513e86f490
|
@ -1,5 +1,5 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from theflow import Function, Node, Param, lazy
|
||||
|
@ -31,6 +31,17 @@ class BaseComponent(Function):
|
|||
|
||||
return self.__call__(self.inflow.flow())
|
||||
|
||||
def set_output_queue(self, queue):
|
||||
self._queue = queue
|
||||
for name in self._ff_nodes:
|
||||
node = getattr(self, name)
|
||||
if isinstance(node, BaseComponent):
|
||||
node.set_output_queue(queue)
|
||||
|
||||
def report_output(self, output: Optional[dict]):
|
||||
if self._queue is not None:
|
||||
self._queue.put_nowait(output)
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self, *args, **kwargs
|
||||
|
|
|
@ -11,7 +11,7 @@ user_cache_dir.mkdir(parents=True, exist_ok=True)
|
|||
|
||||
|
||||
COHERE_API_KEY = config("COHERE_API_KEY", default="")
|
||||
# KH_MODE = "dev"
|
||||
KH_MODE = "dev"
|
||||
KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
|
||||
KH_DOCSTORE = {
|
||||
"__type__": "kotaemon.storages.SimpleFileDocumentStore",
|
||||
|
|
|
@ -32,6 +32,15 @@ footer {
|
|||
height: calc(100vh - 140px) !important;
|
||||
}
|
||||
|
||||
#chat-info-panel {
|
||||
max-height: calc(100vh - 140px) !important;
|
||||
overflow-y: scroll !important;
|
||||
}
|
||||
|
||||
.setting-answer-mode-description {
|
||||
margin: 5px 5px 2px !important
|
||||
}
|
||||
|
||||
mark {
|
||||
background-color: #1496bb;
|
||||
}
|
||||
|
|
|
@ -23,6 +23,9 @@ class ChatPage(BasePage):
|
|||
self.report_issue = ReportIssue(self._app)
|
||||
with gr.Column(scale=6):
|
||||
self.chat_panel = ChatPanel(self._app)
|
||||
with gr.Column(scale=3):
|
||||
with gr.Accordion(label="Information panel", open=True):
|
||||
self.info_panel = gr.Markdown(elem_id="chat-info-panel")
|
||||
|
||||
def on_register_events(self):
|
||||
self.chat_panel.submit_btn.click(
|
||||
|
@ -33,7 +36,11 @@ class ChatPage(BasePage):
|
|||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
],
|
||||
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
||||
outputs=[
|
||||
self.chat_panel.text_input,
|
||||
self.chat_panel.chatbot,
|
||||
self.info_panel,
|
||||
],
|
||||
).then(
|
||||
fn=update_data_source,
|
||||
inputs=[
|
||||
|
@ -52,7 +59,11 @@ class ChatPage(BasePage):
|
|||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
],
|
||||
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
||||
outputs=[
|
||||
self.chat_panel.text_input,
|
||||
self.chat_panel.chatbot,
|
||||
self.info_panel,
|
||||
],
|
||||
).then(
|
||||
fn=update_data_source,
|
||||
inputs=[
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from copy import deepcopy
|
||||
|
@ -116,24 +117,33 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
|||
return pipeline
|
||||
|
||||
|
||||
def chat_fn(chat_input, chat_history, files, settings):
|
||||
async def chat_fn(chat_input, chat_history, files, settings):
|
||||
"""Chat function"""
|
||||
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
|
||||
|
||||
# construct the pipeline
|
||||
pipeline = create_pipeline(settings, files)
|
||||
pipeline.set_output_queue(queue)
|
||||
|
||||
text = ""
|
||||
refs = []
|
||||
for response in pipeline(chat_input):
|
||||
if response.metadata.get("citation", None):
|
||||
citation = response.metadata["citation"]
|
||||
for idx, fact_with_evidence in enumerate(citation.answer):
|
||||
quotes = fact_with_evidence.substring_quote
|
||||
if quotes:
|
||||
refs.append(
|
||||
(None, f"***Reference {idx+1}***: {' ... '.join(quotes)}")
|
||||
)
|
||||
else:
|
||||
text += response.text
|
||||
asyncio.create_task(pipeline(chat_input))
|
||||
text, refs = "", ""
|
||||
|
||||
yield "", chat_history + [(chat_input, text)] + refs
|
||||
while True:
|
||||
try:
|
||||
response = queue.get_nowait()
|
||||
except Exception:
|
||||
await asyncio.sleep(0)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
break
|
||||
|
||||
if "output" in response:
|
||||
text += response["output"]
|
||||
if "evidence" in response:
|
||||
refs += response["evidence"]
|
||||
|
||||
yield "", chat_history + [(chat_input, text)], refs
|
||||
|
||||
|
||||
def is_liked(convo_id, liked: gr.LikeData):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
|
@ -275,7 +276,7 @@ class AnswerWithContextPipeline(BaseComponent):
|
|||
system_prompt: str = ""
|
||||
lang: str = "English" # support English and Japanese
|
||||
|
||||
def run(
|
||||
async def run( # type: ignore
|
||||
self, question: str, evidence: str, evidence_mode: int = 0
|
||||
) -> Document | Iterator[Document]:
|
||||
"""Answer the question based on the evidence
|
||||
|
@ -318,12 +319,16 @@ class AnswerWithContextPipeline(BaseComponent):
|
|||
SystemMessage(content="You are a helpful assistant"),
|
||||
HumanMessage(content=prompt),
|
||||
]
|
||||
# output = self.llm(messages).text
|
||||
yield from self.llm(messages)
|
||||
output = ""
|
||||
for text in self.llm(messages):
|
||||
output += text.text
|
||||
self.report_output({"output": text.text})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
citation = self.citation_pipeline(context=evidence, question=question)
|
||||
answer = Document(text="", metadata={"citation": citation})
|
||||
yield answer
|
||||
answer = Document(text=output, metadata={"citation": citation})
|
||||
|
||||
return answer
|
||||
|
||||
|
||||
class FullQAPipeline(BaseComponent):
|
||||
|
@ -337,13 +342,56 @@ class FullQAPipeline(BaseComponent):
|
|||
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
|
||||
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
|
||||
|
||||
def run(self, question: str, **kwargs) -> Iterator[Document]:
|
||||
async def run(self, question: str, **kwargs) -> Document: # type: ignore
|
||||
docs = self.retrieval_pipeline(text=question)
|
||||
evidence_mode, evidence = self.evidence_pipeline(docs).content
|
||||
answer = self.answering_pipeline(
|
||||
answer = await self.answering_pipeline(
|
||||
question=question, evidence=evidence, evidence_mode=evidence_mode
|
||||
)
|
||||
yield from answer # should be a generator
|
||||
|
||||
# prepare citation
|
||||
from collections import defaultdict
|
||||
|
||||
spans = defaultdict(list)
|
||||
for fact_with_evidence in answer.metadata["citation"].answer:
|
||||
for quote in fact_with_evidence.substring_quote:
|
||||
for doc in docs:
|
||||
start_idx = doc.text.find(quote)
|
||||
if start_idx >= 0:
|
||||
spans[doc.doc_id].append(
|
||||
{
|
||||
"start": start_idx,
|
||||
"end": start_idx + len(quote),
|
||||
}
|
||||
)
|
||||
break
|
||||
|
||||
id2docs = {doc.doc_id: doc for doc in docs}
|
||||
for id, ss in spans.items():
|
||||
if not ss:
|
||||
continue
|
||||
ss = sorted(ss, key=lambda x: x["start"])
|
||||
text = id2docs[id].text[: ss[0]["start"]]
|
||||
for idx, span in enumerate(ss):
|
||||
text += (
|
||||
"<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
|
||||
)
|
||||
if idx < len(ss) - 1:
|
||||
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
|
||||
text += id2docs[id].text[ss[-1]["end"] :]
|
||||
self.report_output(
|
||||
{
|
||||
"evidence": (
|
||||
"<details>"
|
||||
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
|
||||
f"{text}"
|
||||
"</details><br>"
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
self.report_output(None)
|
||||
return answer
|
||||
|
||||
@classmethod
|
||||
def get_pipeline(cls, settings, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue
Block a user