diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py
index 8c24c14..2f8abf5 100644
--- a/libs/kotaemon/kotaemon/base/component.py
+++ b/libs/kotaemon/kotaemon/base/component.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Iterator
+from typing import Iterator, Optional
 
 from kotaemon.base.schema import Document
 from theflow import Function, Node, Param, lazy
@@ -31,6 +31,17 @@ class BaseComponent(Function):
 
         return self.__call__(self.inflow.flow())
 
+    def set_output_queue(self, queue):
+        self._queue = queue
+        for name in self._ff_nodes:
+            node = getattr(self, name)
+            if isinstance(node, BaseComponent):
+                node.set_output_queue(queue)
+
+    def report_output(self, output: Optional[dict]):
+        if self._queue is not None:
+            self._queue.put_nowait(output)
+
     @abstractmethod
     def run(
         self, *args, **kwargs
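Note that `report_output` reads `self._queue`, so a component needs the queue set (or a default) before anything reports. Below is a minimal, self-contained sketch of the plumbing this hunk adds; `Component` is a hypothetical stand-in for the theflow-backed `BaseComponent`, with a class-level `_queue = None` default standing in for whatever theflow provides:

```python
import asyncio
from typing import Optional


class Component:
    """Stand-in for BaseComponent; only the queue plumbing is shown."""

    _queue: Optional[asyncio.Queue] = None  # default so report_output is safe pre-wiring

    def __init__(self, **children: "Component"):
        self._ff_nodes = list(children)  # theflow tracks child nodes by name
        for name, child in children.items():
            setattr(self, name, child)

    def set_output_queue(self, queue: asyncio.Queue) -> None:
        # Share one queue with the entire component tree.
        self._queue = queue
        for name in self._ff_nodes:
            node = getattr(self, name)
            if isinstance(node, Component):
                node.set_output_queue(queue)

    def report_output(self, output: Optional[dict]) -> None:
        if self._queue is not None:
            self._queue.put_nowait(output)


answering = Component()
pipeline = Component(answering=answering)
pipeline.set_output_queue(asyncio.Queue())
answering.report_output({"output": "token"})  # a nested node reports...
assert pipeline._queue is not None
print(pipeline._queue.get_nowait())  # ...and it lands on the shared queue
```

Setting the queue once on the outermost pipeline is what lets a deeply nested node (here, the answering pipeline) stream partial output to a single consumer.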
diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py
index 74e5cbc..dbf19ba 100644
--- a/libs/ktem/flowsettings.py
+++ b/libs/ktem/flowsettings.py
@@ -11,7 +11,7 @@
 user_cache_dir.mkdir(parents=True, exist_ok=True)
 
 COHERE_API_KEY = config("COHERE_API_KEY", default="")
-# KH_MODE = "dev"
+KH_MODE = "dev"
 KH_DATABASE = f"sqlite:///{user_cache_dir / 'sql.db'}"
 KH_DOCSTORE = {
     "__type__": "kotaemon.storages.SimpleFileDocumentStore",
diff --git a/libs/ktem/ktem/assets/css/main.css b/libs/ktem/ktem/assets/css/main.css
index 15c2aa2..0ed480a 100644
--- a/libs/ktem/ktem/assets/css/main.css
+++ b/libs/ktem/ktem/assets/css/main.css
@@ -32,6 +32,15 @@ footer {
     height: calc(100vh - 140px) !important;
 }
 
+#chat-info-panel {
+    max-height: calc(100vh - 140px) !important;
+    overflow-y: scroll !important;
+}
+
 .setting-answer-mode-description {
     margin: 5px 5px 2px !important
 }
+
+mark {
+    background-color: #1496bb;
+}
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index e148c1d..8bd4217 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -23,6 +23,9 @@ class ChatPage(BasePage):
                 self.report_issue = ReportIssue(self._app)
             with gr.Column(scale=6):
                 self.chat_panel = ChatPanel(self._app)
+            with gr.Column(scale=3):
+                with gr.Accordion(label="Information panel", open=True):
+                    self.info_panel = gr.Markdown(elem_id="chat-info-panel")
 
     def on_register_events(self):
         self.chat_panel.submit_btn.click(
@@ -33,7 +36,11 @@ class ChatPage(BasePage):
                 self.data_source.files,
                 self._app.settings_state,
             ],
-            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+            outputs=[
+                self.chat_panel.text_input,
+                self.chat_panel.chatbot,
+                self.info_panel,
+            ],
         ).then(
             fn=update_data_source,
             inputs=[
@@ -52,7 +59,11 @@ class ChatPage(BasePage):
                 self.data_source.files,
                 self._app.settings_state,
             ],
-            outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
+            outputs=[
+                self.chat_panel.text_input,
+                self.chat_panel.chatbot,
+                self.info_panel,
+            ],
         ).then(
             fn=update_data_source,
             inputs=[
diff --git a/libs/ktem/ktem/pages/chat/events.py b/libs/ktem/ktem/pages/chat/events.py
index 9a3346d..18f6506 100644
--- a/libs/ktem/ktem/pages/chat/events.py
+++ b/libs/ktem/ktem/pages/chat/events.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 import tempfile
 from copy import deepcopy
@@ -116,24 +117,33 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
     return pipeline
 
 
-def chat_fn(chat_input, chat_history, files, settings):
+async def chat_fn(chat_input, chat_history, files, settings):
+    """Chat function"""
+    queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
+
+    # construct the pipeline
     pipeline = create_pipeline(settings, files)
+    pipeline.set_output_queue(queue)
 
-    text = ""
-    refs = []
-    for response in pipeline(chat_input):
-        if response.metadata.get("citation", None):
-            citation = response.metadata["citation"]
-            for idx, fact_with_evidence in enumerate(citation.answer):
-                quotes = fact_with_evidence.substring_quote
-                if quotes:
-                    refs.append(
-                        (None, f"***Reference {idx+1}***: {' ... '.join(quotes)}")
-                    )
-        else:
-            text += response.text
+    asyncio.create_task(pipeline(chat_input))
+    text, refs = "", ""
 
-        yield "", chat_history + [(chat_input, text)] + refs
+    while True:
+        try:
+            response = queue.get_nowait()
+        except Exception:
+            await asyncio.sleep(0)
+            continue
+
+        if response is None:
+            break
+
+        if "output" in response:
+            text += response["output"]
+        if "evidence" in response:
+            refs += response["evidence"]
+
+        yield "", chat_history + [(chat_input, text)], refs
 
 
 def is_liked(convo_id, liked: gr.LikeData):
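The new `chat_fn` is the consumer side of that queue: it schedules the pipeline as a background task, drains `{"output": ...}` and `{"evidence": ...}` dicts as they arrive, and stops at the `None` sentinel. Here is a runnable sketch of the same handshake, with a hypothetical `fake_pipeline` in place of the real one; it uses `await queue.get()` where the patch busy-polls `get_nowait()` plus `asyncio.sleep(0)`, which behaves the same here because the producer always terminates with the sentinel:

```python
import asyncio
from typing import Optional


async def fake_pipeline(queue: asyncio.Queue) -> None:
    # Stands in for pipeline(chat_input): stream chunks, then the None sentinel.
    for chunk in ("Hel", "lo", "!"):
        queue.put_nowait({"output": chunk})
        await asyncio.sleep(0)  # yield control, as the answering pipeline does
    queue.put_nowait(None)


async def consume() -> None:
    queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
    asyncio.create_task(fake_pipeline(queue))

    text = ""
    while True:
        response = await queue.get()  # blocking get instead of get_nowait/sleep(0)
        if response is None:
            break
        if "output" in response:
            text += response["output"]
        print(text)  # where chat_fn yields to the chatbot and info panel


asyncio.run(consume())  # prints "Hel", "Hello", "Hello!"
```

The `asyncio.sleep(0)` on the producer side matters too: without it, a synchronous LLM loop would never let the consumer run until the whole answer had been generated.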
+ "" + ) + if idx < len(ss) - 1: + text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] + text += id2docs[id].text[ss[-1]["end"] :] + self.report_output( + { + "evidence": ( + "
" + f"{id2docs[id].metadata['file_name']}" + f"{text}" + "

" + ) + } + ) + + self.report_output(None) + return answer @classmethod def get_pipeline(cls, settings, **kwargs):