Add dedicated information panel to the UI (#137)
* Allow streaming to the chatbot and the information panel without threading
* Highlight evidence in a simple manner
This commit is contained in:
parent
ebc61400d8
commit
513e86f490
|
@ -1,5 +1,5 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Iterator
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
from kotaemon.base.schema import Document
|
from kotaemon.base.schema import Document
|
||||||
from theflow import Function, Node, Param, lazy
|
from theflow import Function, Node, Param, lazy
|
||||||
|
@ -31,6 +31,17 @@ class BaseComponent(Function):
|
||||||
|
|
||||||
return self.__call__(self.inflow.flow())
|
return self.__call__(self.inflow.flow())
|
||||||
|
|
||||||
|
def set_output_queue(self, queue):
    """Attach ``queue`` to this component and, recursively, to its children.

    The queue is stored on ``self._queue``. Every attribute named in
    ``self._ff_nodes`` is then looked up, and each one that is itself a
    ``BaseComponent`` receives the same queue through its own
    ``set_output_queue``.

    Args:
        queue: the object later used by ``report_output`` (it must provide
            ``put_nowait``).
    """
    self._queue = queue

    # NOTE(review): `_ff_nodes` is presumably populated by theflow with the
    # names of the declared nodes -- confirm against the theflow docs.
    children = (getattr(self, node_name) for node_name in self._ff_nodes)
    for child in children:
        if not isinstance(child, BaseComponent):
            # plain (non-component) nodes have nothing to report
            continue
        child.set_output_queue(queue)
|
||||||
|
|
||||||
|
def report_output(self, output: Optional[dict]):
    """Push ``output`` onto the attached output queue, if there is one.

    The queue is only assigned by ``set_output_queue``; a component that is
    run without one (for example outside the chat UI) has no ``_queue``
    attribute at all, so the lookup falls back to ``None`` instead of raising
    ``AttributeError``.

    Args:
        output: the payload to enqueue. ``None`` is forwarded unchanged so
            callers can use it as an end-of-stream marker.
    """
    queue = getattr(self, "_queue", None)
    if queue is not None:
        # non-blocking put: reporting must never stall the pipeline
        queue.put_nowait(output)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(
|
def run(
|
||||||
self, *args, **kwargs
|
self, *args, **kwargs
|
||||||
|
|
|
@ -11,7 +11,7 @@ user_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
COHERE_API_KEY = config("COHERE_API_KEY", default="")
|
COHERE_API_KEY = config("COHERE_API_KEY", default="")
|
||||||
# KH_MODE = "dev"
|
KH_MODE = "dev"
|
||||||
KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
|
KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
|
||||||
KH_DOCSTORE = {
|
KH_DOCSTORE = {
|
||||||
"__type__": "kotaemon.storages.SimpleFileDocumentStore",
|
"__type__": "kotaemon.storages.SimpleFileDocumentStore",
|
||||||
|
|
|
@ -32,6 +32,15 @@ footer {
|
||||||
height: calc(100vh - 140px) !important;
|
height: calc(100vh - 140px) !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#chat-info-panel {
|
||||||
|
max-height: calc(100vh - 140px) !important;
|
||||||
|
overflow-y: scroll !important;
|
||||||
|
}
|
||||||
|
|
||||||
.setting-answer-mode-description {
|
.setting-answer-mode-description {
|
||||||
margin: 5px 5px 2px !important
|
margin: 5px 5px 2px !important
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mark {
|
||||||
|
background-color: #1496bb;
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,9 @@ class ChatPage(BasePage):
|
||||||
self.report_issue = ReportIssue(self._app)
|
self.report_issue = ReportIssue(self._app)
|
||||||
with gr.Column(scale=6):
|
with gr.Column(scale=6):
|
||||||
self.chat_panel = ChatPanel(self._app)
|
self.chat_panel = ChatPanel(self._app)
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
with gr.Accordion(label="Information panel", open=True):
|
||||||
|
self.info_panel = gr.Markdown(elem_id="chat-info-panel")
|
||||||
|
|
||||||
def on_register_events(self):
|
def on_register_events(self):
|
||||||
self.chat_panel.submit_btn.click(
|
self.chat_panel.submit_btn.click(
|
||||||
|
@ -33,7 +36,11 @@ class ChatPage(BasePage):
|
||||||
self.data_source.files,
|
self.data_source.files,
|
||||||
self._app.settings_state,
|
self._app.settings_state,
|
||||||
],
|
],
|
||||||
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
outputs=[
|
||||||
|
self.chat_panel.text_input,
|
||||||
|
self.chat_panel.chatbot,
|
||||||
|
self.info_panel,
|
||||||
|
],
|
||||||
).then(
|
).then(
|
||||||
fn=update_data_source,
|
fn=update_data_source,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
@ -52,7 +59,11 @@ class ChatPage(BasePage):
|
||||||
self.data_source.files,
|
self.data_source.files,
|
||||||
self._app.settings_state,
|
self._app.settings_state,
|
||||||
],
|
],
|
||||||
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
outputs=[
|
||||||
|
self.chat_panel.text_input,
|
||||||
|
self.chat_panel.chatbot,
|
||||||
|
self.info_panel,
|
||||||
|
],
|
||||||
).then(
|
).then(
|
||||||
fn=update_data_source,
|
fn=update_data_source,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
@ -116,24 +117,33 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
async def chat_fn(chat_input, chat_history, files, settings):
    """Chat function: stream the answer and its evidence to the UI.

    The pipeline is started as an asyncio task and reports its progress
    through a queue:

    * a dict with an ``"output"`` key extends the answer text,
    * a dict with an ``"evidence"`` key extends the information-panel content,
    * ``None`` marks the end of the stream.

    Args:
        chat_input: the message the user just submitted.
        chat_history: the previous chatbot turns; the new turn is appended.
        files: passed through to ``create_pipeline``.
        settings: passed through to ``create_pipeline``.

    Yields:
        ``("", updated_chat_history, refs)`` after every reported item. The
        empty string is the value sent to the first output (the text input).

    Raises:
        Whatever the pipeline task raised, if it ends without having sent the
        end-of-stream marker.
    """
    queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()

    # construct the pipeline
    pipeline = create_pipeline(settings, files)
    pipeline.set_output_queue(queue)

    # Keep a strong reference to the task: the event loop only holds a weak
    # one, so an unreferenced task may be garbage-collected mid-execution.
    task = asyncio.create_task(pipeline(chat_input))
    text, refs = "", ""

    while True:
        try:
            response = queue.get_nowait()
        except asyncio.QueueEmpty:
            if task.done():
                # The pipeline stopped without sending the end marker:
                # re-raise its error (if any) instead of polling forever.
                task.result()
                break
            # hand control back to the event loop so the pipeline can progress
            await asyncio.sleep(0)
            continue

        if response is None:
            break

        if "output" in response:
            text += response["output"]
        if "evidence" in response:
            refs += response["evidence"]

        yield "", chat_history + [(chat_input, text)], refs
|
||||||
|
|
||||||
|
|
||||||
def is_liked(convo_id, liked: gr.LikeData):
|
def is_liked(convo_id, liked: gr.LikeData):
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
@ -275,7 +276,7 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
system_prompt: str = ""
|
system_prompt: str = ""
|
||||||
lang: str = "English" # support English and Japanese
|
lang: str = "English" # support English and Japanese
|
||||||
|
|
||||||
def run(
|
async def run( # type: ignore
|
||||||
self, question: str, evidence: str, evidence_mode: int = 0
|
self, question: str, evidence: str, evidence_mode: int = 0
|
||||||
) -> Document | Iterator[Document]:
|
) -> Document | Iterator[Document]:
|
||||||
"""Answer the question based on the evidence
|
"""Answer the question based on the evidence
|
||||||
|
@ -318,12 +319,16 @@ class AnswerWithContextPipeline(BaseComponent):
|
||||||
SystemMessage(content="You are a helpful assistant"),
|
SystemMessage(content="You are a helpful assistant"),
|
||||||
HumanMessage(content=prompt),
|
HumanMessage(content=prompt),
|
||||||
]
|
]
|
||||||
# output = self.llm(messages).text
|
output = ""
|
||||||
yield from self.llm(messages)
|
for text in self.llm(messages):
|
||||||
|
output += text.text
|
||||||
|
self.report_output({"output": text.text})
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
citation = self.citation_pipeline(context=evidence, question=question)
|
citation = self.citation_pipeline(context=evidence, question=question)
|
||||||
answer = Document(text="", metadata={"citation": citation})
|
answer = Document(text=output, metadata={"citation": citation})
|
||||||
yield answer
|
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
class FullQAPipeline(BaseComponent):
|
class FullQAPipeline(BaseComponent):
|
||||||
|
@ -337,13 +342,56 @@ class FullQAPipeline(BaseComponent):
|
||||||
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
|
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
|
||||||
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
|
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
|
||||||
|
|
||||||
async def run(self, question: str, **kwargs) -> Document:  # type: ignore
    """Answer ``question`` and report the supporting evidence.

    Steps: retrieve documents, prepare the evidence, await the answering
    pipeline, then locate every cited quote in the retrieved documents and
    report each matching document through ``report_output`` as a
    ``<details>`` HTML snippet with the quotes wrapped in ``<mark>``.

    ``report_output(None)`` is always sent last, even when a step raises, so
    that whoever consumes the output queue is never left waiting.

    Args:
        question: the user question.
        **kwargs: accepted for interface compatibility; not used here.

    Returns:
        The value produced by the answering pipeline (its
        ``metadata["citation"]`` is read to build the evidence).
    """
    try:
        docs = self.retrieval_pipeline(text=question)
        evidence_mode, evidence = self.evidence_pipeline(docs).content
        answer = await self.answering_pipeline(
            question=question, evidence=evidence, evidence_mode=evidence_mode
        )

        # prepare citation: each quote is attributed to the first retrieved
        # document whose text contains it
        spans = defaultdict(list)
        for fact_with_evidence in answer.metadata["citation"].answer:
            for quote in fact_with_evidence.substring_quote:
                for doc in docs:
                    start_idx = doc.text.find(quote)
                    if start_idx >= 0:
                        spans[doc.doc_id].append(
                            (start_idx, start_idx + len(quote))
                        )
                        break

        id2docs = {doc.doc_id: doc for doc in docs}
        for doc_id, doc_spans in spans.items():
            doc = id2docs[doc_id]

            # Merge overlapping (or repeated) spans; highlighting them one by
            # one would emit the shared characters more than once.
            merged: list = []
            for start, end in sorted(doc_spans):
                if merged and start < merged[-1][1]:
                    merged[-1][1] = max(merged[-1][1], end)
                else:
                    merged.append([start, end])

            # Rebuild the text left to right, wrapping each span in <mark>.
            parts = []
            cursor = 0
            for start, end in merged:
                parts.append(doc.text[cursor:start])
                parts.append("<mark>" + doc.text[start:end] + "</mark>")
                cursor = end
            parts.append(doc.text[cursor:])
            text = "".join(parts)

            self.report_output(
                {
                    "evidence": (
                        "<details>"
                        f"<summary>{doc.metadata['file_name']}</summary>"
                        f"{text}"
                        "</details><br>"
                    )
                }
            )
    finally:
        # end-of-stream marker for the queue consumer
        self.report_output(None)

    return answer
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_pipeline(cls, settings, **kwargs):
|
def get_pipeline(cls, settings, **kwargs):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user