diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py
index 8c24c14..2f8abf5 100644
--- a/libs/kotaemon/kotaemon/base/component.py
+++ b/libs/kotaemon/kotaemon/base/component.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Iterator
+from typing import Iterator, Optional

 from kotaemon.base.schema import Document
 from theflow import Function, Node, Param, lazy
@@ -31,6 +31,17 @@ class BaseComponent(Function):
         return self.__call__(self.inflow.flow())

+    def set_output_queue(self, queue):
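+        # share one queue across the whole component tree so nested
+        # components stream their outputs to the same consumer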
+        self._queue = queue
+        for name in self._ff_nodes:
+            node = getattr(self, name)
+            if isinstance(node, BaseComponent):
+                node.set_output_queue(queue)
+
+    def report_output(self, output: Optional[dict]):
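+        # forward an intermediate output to the attached queue, if any;
+        # a None output is the end-of-stream sentinel for consumers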
+        if self._queue is not None:
+            self._queue.put_nowait(output)
+
     @abstractmethod
     def run(
         self, *args, **kwargs
diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py
index 74e5cbc..dbf19ba 100644
--- a/libs/ktem/flowsettings.py
+++ b/libs/ktem/flowsettings.py
@@ -11,7 +11,7 @@ user_cache_dir.mkdir(parents=True, exist_ok=True)
 COHERE_API_KEY = config("COHERE_API_KEY", default="")

-# KH_MODE = "dev"
+KH_MODE = "dev"

 KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
 KH_DOCSTORE = {
     "__type__": "kotaemon.storages.SimpleFileDocumentStore",
diff --git a/libs/ktem/ktem/assets/css/main.css b/libs/ktem/ktem/assets/css/main.css
index 15c2aa2..0ed480a 100644
--- a/libs/ktem/ktem/assets/css/main.css
+++ b/libs/ktem/ktem/assets/css/main.css
@@ -32,6 +32,15 @@ footer {
     height: calc(100vh - 140px) !important;
 }

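+/* scrollable container for the streamed evidence/citations */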
+#chat-info-panel {
+    max-height: calc(100vh - 140px) !important;
+    overflow-y: scroll !important;
+}
+
 .setting-answer-mode-description {
     margin: 5px 5px 2px !important
 }
+
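+/* highlight color for cited evidence spans */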
+mark {
+    background-color: #1496bb;
+}
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index e148c1d..8bd4217 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -23,6 +23,9 @@ class ChatPage(BasePage):
                 self.report_issue = ReportIssue(self._app)
             with gr.Column(scale=6):
                 self.chat_panel = ChatPanel(self._app)
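+            # right-hand panel that streams highlighted evidence for the answer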
+            with gr.Column(scale=3):
+                with gr.Accordion(label="Information panel", open=True):
+                    self.info_panel = gr.Markdown(elem_id="chat-info-panel")

     def on_register_events(self):
         self.chat_panel.submit_btn.click(
@@ -33,7 +36,11 @@ class ChatPage(BasePage):
                 self.data_source.files,
                 self._app.settings_state,
             ],
-            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+            outputs=[
+                self.chat_panel.text_input,
+                self.chat_panel.chatbot,
+                self.info_panel,
+            ],
         ).then(
             fn=update_data_source,
             inputs=[
@@ -52,7 +59,11 @@ class ChatPage(BasePage):
                 self.data_source.files,
                 self._app.settings_state,
             ],
-            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+            outputs=[
+                self.chat_panel.text_input,
+                self.chat_panel.chatbot,
+                self.info_panel,
+            ],
         ).then(
             fn=update_data_source,
             inputs=[
diff --git a/libs/ktem/ktem/pages/chat/events.py b/libs/ktem/ktem/pages/chat/events.py
index 9a3346d..18f6506 100644
--- a/libs/ktem/ktem/pages/chat/events.py
+++ b/libs/ktem/ktem/pages/chat/events.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 import tempfile
 from copy import deepcopy
@@ -116,24 +117,33 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
     return pipeline


-def chat_fn(chat_input, chat_history, files, settings):
+async def chat_fn(chat_input, chat_history, files, settings):
+    """Run one chat turn: execute the pipeline and stream its outputs to the UI"""
+    queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
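+    # pipeline components push intermediate outputs onto this queue;
+    # None marks the end of the stream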
+
+    # construct the pipeline
     pipeline = create_pipeline(settings, files)
+    pipeline.set_output_queue(queue)

-    text = ""
-    refs = []
-    for response in pipeline(chat_input):
-        if response.metadata.get("citation", None):
-            citation = response.metadata["citation"]
-            for idx, fact_with_evidence in enumerate(citation.answer):
-                quotes = fact_with_evidence.substring_quote
-                if quotes:
-                    refs.append(
-                        (None, f"***Reference {idx+1}***: {' ... '.join(quotes)}")
-                    )
-        else:
-            text += response.text
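+    # run the pipeline as a background task; it reports outputs via the queue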
+    asyncio.create_task(pipeline(chat_input))
+    text, refs = "", ""
-        yield "", chat_history + [(chat_input, text)] + refs
+    while True:
+        try:
+            response = queue.get_nowait()
+        except asyncio.QueueEmpty:
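+            # nothing queued yet; yield control so the pipeline task can run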
+            await asyncio.sleep(0)
+            continue
+
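+        # the pipeline signals completion with None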
+        if response is None:
+            break
+
+        if "output" in response:
+            text += response["output"]
+        if "evidence" in response:
+            refs += response["evidence"]
+
+        yield "", chat_history + [(chat_input, text)], refs


 def is_liked(convo_id, liked: gr.LikeData):
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index 5cfc819..cab5063 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -1,3 +1,4 @@
+import asyncio
 import logging
 import warnings
 from collections import defaultdict
@@ -275,7 +276,7 @@ class AnswerWithContextPipeline(BaseComponent):
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese

-    def run(
+    async def run(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0
     ) -> Document | Iterator[Document]:
         """Answer the question based on the evidence
@@ -318,12 +319,16 @@ class AnswerWithContextPipeline(BaseComponent):
             SystemMessage(content="You are a helpful assistant"),
             HumanMessage(content=prompt),
         ]
-        # output = self.llm(messages).text
-        yield from self.llm(messages)
+        output = ""
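+        # stream each token to the UI while accumulating the full answer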
+        for text in self.llm(messages):
+            output += text.text
+            self.report_output({"output": text.text})
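+            # let the event loop deliver the queued token to chat_fn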
+            await asyncio.sleep(0)

         citation = self.citation_pipeline(context=evidence, question=question)
-        answer = Document(text="", metadata={"citation": citation})
-        yield answer
+        answer = Document(text=output, metadata={"citation": citation})
+
+        return answer


 class FullQAPipeline(BaseComponent):
@@ -337,13 +342,56 @@ class FullQAPipeline(BaseComponent):
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()

-    def run(self, question: str, **kwargs) -> Iterator[Document]:
+    async def run(self, question: str, **kwargs) -> Document:  # type: ignore
         docs = self.retrieval_pipeline(text=question)
         evidence_mode, evidence = self.evidence_pipeline(docs).content
-        answer = self.answering_pipeline(
+        answer = await self.answering_pipeline(
             question=question, evidence=evidence, evidence_mode=evidence_mode
         )
-        yield from answer  # should be a generator
+
+        # prepare citation highlights
+
+        spans = defaultdict(list)
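+        # locate every cited quote inside the retrieved documents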
+        for fact_with_evidence in answer.metadata["citation"].answer:
+            for quote in fact_with_evidence.substring_quote:
+                for doc in docs:
+                    start_idx = doc.text.find(quote)
+                    if start_idx >= 0:
+                        spans[doc.doc_id].append(
+                            {
+                                "start": start_idx,
+                                "end": start_idx + len(quote),
+                            }
+                        )
+                        break
+
+        id2docs = {doc.doc_id: doc for doc in docs}
+        for id, ss in spans.items():
+            if not ss:
+                continue
+            ss = sorted(ss, key=lambda x: x["start"])
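+            # rebuild the document text, wrapping cited spans in <mark> tags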
+            text = id2docs[id].text[: ss[0]["start"]]
+            for idx, span in enumerate(ss):
+                text += (
+                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
+                )
+                if idx < len(ss) - 1:
+                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
+            text += id2docs[id].text[ss[-1]["end"] :]
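+            # send the highlighted document to the info panel as collapsible HTML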
+            self.report_output(
+                {
+                    "evidence": (
+                        "<details>"
+                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
+                        f"{text}"
+                        "</details><br>"
+                    )
+                }
+            )
+
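+        # end-of-stream sentinel: lets chat_fn stop polling the queue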
+        self.report_output(None)
+        return answer

     @classmethod
     def get_pipeline(cls, settings, **kwargs):