diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py index 6936b2a..230ce9d 100644 --- a/libs/kotaemon/kotaemon/base/component.py +++ b/libs/kotaemon/kotaemon/base/component.py @@ -39,7 +39,7 @@ class BaseComponent(Function): if isinstance(node, BaseComponent): node.set_output_queue(queue) - def report_output(self, output: Optional[dict]): + def report_output(self, output: Optional[Document]): if self._queue is not None: self._queue.put_nowait(output) diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py index 6f492c7..1a31e24 100644 --- a/libs/kotaemon/kotaemon/llms/chats/openai.py +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -270,7 +270,7 @@ class ChatOpenAI(BaseChatOpenAI): def openai_response(self, client, **kwargs): """Get the openai response""" - params = { + params_ = { "model": self.model, "temperature": self.temperature, "max_tokens": self.max_tokens, @@ -285,6 +285,7 @@ class ChatOpenAI(BaseChatOpenAI): "top_logprobs": self.top_logprobs, "top_p": self.top_p, } + params = {k: v for k, v in params_.items() if v is not None} params.update(kwargs) return client.chat.completions.create(**params) diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py index 0ef64e0..71ad425 100644 --- a/libs/ktem/ktem/llms/manager.py +++ b/libs/ktem/ktem/llms/manager.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from theflow.settings import settings as flowsettings from theflow.utils.modules import deserialize -from kotaemon.base import BaseComponent +from kotaemon.llms import ChatLLM from .db import LLMTable, engine @@ -14,7 +14,7 @@ class LLMManager: """Represent a pool of models""" def __init__(self): - self._models: dict[str, BaseComponent] = {} + self._models: dict[str, ChatLLM] = {} self._info: dict[str, dict] = {} self._default: str = "" self._vendors: list[Type] = [] @@ -63,7 +63,7 @@ class LLMManager: self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM] - def __getitem__(self, key: str) -> BaseComponent: + def __getitem__(self, key: str) -> ChatLLM: """Get model by name""" return self._models[key] @@ -71,9 +71,7 @@ class LLMManager: """Check if model exists""" return key in self._models - def get( - self, key: str, default: Optional[BaseComponent] = None - ) -> Optional[BaseComponent]: + def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]: """Get model by name with default value""" return self._models.get(key, default) @@ -119,18 +117,18 @@ class LLMManager: return self._default - def get_random(self) -> BaseComponent: + def get_random(self) -> ChatLLM: """Get random model""" return self._models[self.get_random_name()] - def get_default(self) -> BaseComponent: + def get_default(self) -> ChatLLM: """Get default model In case there is no default model, choose random model from pool. In case there are multiple default models, choose random from them. 
Returns: - BaseComponent: model + ChatLLM: model """ return self._models[self.get_default_name()] diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 3397250..d4881d8 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -8,6 +8,7 @@ from typing import Generator import tiktoken from ktem.llms.manager import llms +from ktem.utils.render import Render from kotaemon.base import ( BaseComponent, @@ -20,7 +21,7 @@ from kotaemon.base import ( from kotaemon.indices.qa.citation import CitationPipeline from kotaemon.indices.splitters import TokenSplitter from kotaemon.llms import ChatLLM, PromptTemplate -from kotaemon.loaders.utils.gpt4v import stream_gpt4v +from kotaemon.loaders.utils.gpt4v import generate_gpt4v, stream_gpt4v from .base import BaseReasoning @@ -205,7 +206,68 @@ class AnswerWithContextPipeline(BaseComponent): system_prompt: str = "" lang: str = "English" # support English and Japanese - async def run( # type: ignore + def get_prompt(self, question, evidence, evidence_mode: int): + """Prepare the prompt and other information for LLM""" + images = [] + + if evidence_mode == EVIDENCE_MODE_TEXT: + prompt_template = PromptTemplate(self.qa_template) + elif evidence_mode == EVIDENCE_MODE_TABLE: + prompt_template = PromptTemplate(self.qa_table_template) + elif evidence_mode == EVIDENCE_MODE_FIGURE: + prompt_template = PromptTemplate(self.qa_figure_template) + else: + prompt_template = PromptTemplate(self.qa_chatbot_template) + + if evidence_mode == EVIDENCE_MODE_FIGURE: + # isolate image from evidence + evidence, images = self.extract_evidence_images(evidence) + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + else: + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + + return prompt, images + + def run( + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs + ) -> Document: + return self.invoke(question, evidence, evidence_mode, **kwargs) + + def invoke( + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs + ) -> Document: + prompt, images = self.get_prompt(question, evidence, evidence_mode) + + output = "" + if evidence_mode == EVIDENCE_MODE_FIGURE: + output = generate_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768) + else: + messages = [] + if self.system_prompt: + messages.append(SystemMessage(content=self.system_prompt)) + messages.append(HumanMessage(content=prompt)) + output = self.llm(messages).text + + # retrieve the citation + citation = None + if evidence and self.enable_citation: + citation = self.citation_pipeline.invoke( + context=evidence, question=question + ) + + answer = Document(text=output, metadata={"citation": citation}) + + return answer + + async def ainvoke( # type: ignore self, question: str, evidence: str, evidence_mode: int = 0, **kwargs ) -> Document: """Answer the question based on the evidence @@ -230,30 +292,7 @@ class AnswerWithContextPipeline(BaseComponent): (determined by retrieval pipeline) evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot """ - if evidence_mode == EVIDENCE_MODE_TEXT: - prompt_template = PromptTemplate(self.qa_template) - elif evidence_mode == EVIDENCE_MODE_TABLE: - prompt_template = PromptTemplate(self.qa_table_template) - elif evidence_mode == EVIDENCE_MODE_FIGURE: - prompt_template = PromptTemplate(self.qa_figure_template) - else: - prompt_template = PromptTemplate(self.qa_chatbot_template) - 
- images = [] - if evidence_mode == EVIDENCE_MODE_FIGURE: - # isolate image from evidence - evidence, images = self.extract_evidence_images(evidence) - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) - else: - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) + prompt, images = self.get_prompt(question, evidence, evidence_mode) citation_task = None if evidence and self.enable_citation: @@ -266,7 +305,7 @@ class AnswerWithContextPipeline(BaseComponent): if evidence_mode == EVIDENCE_MODE_FIGURE: for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768): output += text - self.report_output({"output": text}) + self.report_output(Document(channel="chat", content=text)) await asyncio.sleep(0) else: messages = [] @@ -279,12 +318,12 @@ class AnswerWithContextPipeline(BaseComponent): print("Trying LLM streaming") for text in self.llm.stream(messages): output += text.text - self.report_output({"output": text.text}) + self.report_output(Document(content=text.text, channel="chat")) await asyncio.sleep(0) except NotImplementedError: print("Streaming is not supported, falling back to normal processing") output = self.llm(messages).text - self.report_output({"output": output}) + self.report_output(Document(content=output, channel="chat")) # retrieve the citation print("Waiting for citation task") @@ -300,52 +339,7 @@ class AnswerWithContextPipeline(BaseComponent): def stream( # type: ignore self, question: str, evidence: str, evidence_mode: int = 0, **kwargs ) -> Generator[Document, None, Document]: - """Answer the question based on the evidence - - In addition to the question and the evidence, this method also take into - account evidence_mode. The evidence_mode tells which kind of evidence is. - The kind of evidence affects: - 1. How the evidence is represented. - 2. The prompt to generate the answer. - - By default, the evidence_mode is 0, which means the evidence is plain text with - no particular semantic representation. The evidence_mode can be: - 1. "table": There will be HTML markup telling that there is a table - within the evidence. - 2. "chatbot": There will be HTML markup telling that there is a chatbot. - This chatbot is a scenario, extracted from an Excel file, where each - row corresponds to an interaction. 
- - Args: - question: the original question posed by user - evidence: the text that contain relevant information to answer the question - (determined by retrieval pipeline) - evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot - """ - if evidence_mode == EVIDENCE_MODE_TEXT: - prompt_template = PromptTemplate(self.qa_template) - elif evidence_mode == EVIDENCE_MODE_TABLE: - prompt_template = PromptTemplate(self.qa_table_template) - elif evidence_mode == EVIDENCE_MODE_FIGURE: - prompt_template = PromptTemplate(self.qa_figure_template) - else: - prompt_template = PromptTemplate(self.qa_chatbot_template) - - images = [] - if evidence_mode == EVIDENCE_MODE_FIGURE: - # isolate image from evidence - evidence, images = self.extract_evidence_images(evidence) - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) - else: - prompt = prompt_template.populate( - context=evidence, - question=question, - lang=self.lang, - ) + prompt, images = self.get_prompt(question, evidence, evidence_mode) output = "" if evidence_mode == EVIDENCE_MODE_FIGURE: @@ -425,37 +419,112 @@ class FullQAPipeline(BaseReasoning): rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx() use_rewrite: bool = False - async def ainvoke( # type: ignore - self, message: str, conv_id: str, history: list, **kwargs # type: ignore - ) -> Document: # type: ignore - import markdown - - docs = [] - doc_ids = [] - if self.use_rewrite: - rewrite = await self.rewrite_pipeline(question=message) - message = rewrite.text - + def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]: + """Retrieve the documents based on the message""" + docs, doc_ids = [], [] for retriever in self.retrievers: for doc in retriever(text=message): if doc.doc_id not in doc_ids: docs.append(doc) doc_ids.append(doc.doc_id) + + info = [] for doc in docs: - # TODO: a better approach to show the information - text = markdown.markdown( - doc.text, extensions=["markdown.extensions.tables"] + info.append( + Document( + channel="info", + content=Render.collapsible( + header=doc.metadata["file_name"], + content=Render.table(doc.text), + open=True, + ), + ) ) - self.report_output( - { - "evidence": ( - "
" - f"{doc.metadata['file_name']}" - f"{text}" - "

" - ) - } + + return docs, info + + def prepare_citations(self, answer, docs) -> tuple[list[Document], list[Document]]: + """Prepare the citations to show on the UI""" + with_citation, without_citation = [], [] + spans = defaultdict(list) + + if answer.metadata["citation"] is not None: + for fact_with_evidence in answer.metadata["citation"].answer: + for quote in fact_with_evidence.substring_quote: + for doc in docs: + start_idx = doc.text.find(quote) + if start_idx == -1: + continue + + end_idx = start_idx + len(quote) + + current_idx = start_idx + if "|" not in doc.text[start_idx:end_idx]: + spans[doc.doc_id].append( + {"start": start_idx, "end": end_idx} + ) + else: + while doc.text[current_idx:end_idx].find("|") != -1: + match_idx = doc.text[current_idx:end_idx].find("|") + spans[doc.doc_id].append( + { + "start": current_idx, + "end": current_idx + match_idx, + } + ) + current_idx += match_idx + 2 + if current_idx > end_idx: + break + break + + id2docs = {doc.doc_id: doc for doc in docs} + not_detected = set(id2docs.keys()) - set(spans.keys()) + for id, ss in spans.items(): + if not ss: + not_detected.add(id) + continue + ss = sorted(ss, key=lambda x: x["start"]) + text = id2docs[id].text[: ss[0]["start"]] + for idx, span in enumerate(ss): + text += Render.highlight(id2docs[id].text[span["start"] : span["end"]]) + if idx < len(ss) - 1: + text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] + text += id2docs[id].text[ss[-1]["end"] :] + with_citation.append( + Document( + channel="info", + content=Render.collapsible( + header=id2docs[id].metadata["file_name"], + content=Render.table(text), + open=True, + ), + ) ) + + without_citation = [ + Document( + channel="info", + content=Render.collapsible( + header=id2docs[id].metadata["file_name"], + content=Render.table(id2docs[id].text), + open=False, + ), + ) + for id in list(not_detected) + ] + + return with_citation, without_citation + + async def ainvoke( # type: ignore + self, message: str, conv_id: str, history: list, **kwargs # type: ignore + ) -> Document: # type: ignore + if self.use_rewrite: + rewrite = await self.rewrite_pipeline(question=message) + message = rewrite.text + + docs, infos = self.retrieve(message) + for _ in infos: + self.report_output(_) await asyncio.sleep(0.1) evidence_mode, evidence = self.evidence_pipeline(docs).content @@ -468,90 +537,23 @@ class FullQAPipeline(BaseReasoning): **kwargs, ) - # prepare citation - spans = defaultdict(list) - if answer.metadata["citation"] is not None: - for fact_with_evidence in answer.metadata["citation"].answer: - for quote in fact_with_evidence.substring_quote: - for doc in docs: - start_idx = doc.text.find(quote) - if start_idx == -1: - continue - - end_idx = start_idx + len(quote) - - current_idx = start_idx - if "|" not in doc.text[start_idx:end_idx]: - spans[doc.doc_id].append( - {"start": start_idx, "end": end_idx} - ) - else: - while doc.text[current_idx:end_idx].find("|") != -1: - match_idx = doc.text[current_idx:end_idx].find("|") - spans[doc.doc_id].append( - { - "start": current_idx, - "end": current_idx + match_idx, - } - ) - current_idx += match_idx + 2 - if current_idx > end_idx: - break - break - - id2docs = {doc.doc_id: doc for doc in docs} - lack_evidence = True - not_detected = set(id2docs.keys()) - set(spans.keys()) - self.report_output({"evidence": None}) - for id, ss in spans.items(): - if not ss: - not_detected.add(id) - continue - ss = sorted(ss, key=lambda x: x["start"]) - text = id2docs[id].text[: ss[0]["start"]] - for idx, span in enumerate(ss): 
-                text += (
-                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
-                )
-                if idx < len(ss) - 1:
-                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
-            text += id2docs[id].text[ss[-1]["end"] :]
-            text_out = markdown.markdown(
-                text, extensions=["markdown.extensions.tables"]
-            )
-            self.report_output(
-                {
-                    "evidence": (
-                        "<details open>"
-                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
-                        f"{text_out}"
-                        "</details><br>
" - ) - } - ) - lack_evidence = False - - if lack_evidence: - self.report_output({"evidence": "No evidence found.\n"}) - - if not_detected: - self.report_output( - {"evidence": "Retrieved segments without matching evidence:\n"} - ) - for id in list(not_detected): - text_out = markdown.markdown( - id2docs[id].text, extensions=["markdown.extensions.tables"] - ) + # show the evidence + with_citation, without_citation = self.prepare_citations(answer, docs) + if not with_citation and not without_citation: + self.report_output(Document(channel="info", content="No evidence found.\n")) + else: + self.report_output(Document(channel="info", content=None)) + for _ in with_citation: + self.report_output(_) + if without_citation: self.report_output( - { - "evidence": ( - "
" - f"{id2docs[id].metadata['file_name']}" - f"{text_out}" - "

" - ) - } + Document( + channel="info", + content="Retrieved segments without matching evidence:\n", + ) ) + for _ in without_citation: + self.report_output(_) self.report_output(None) return answer @@ -559,32 +561,12 @@ class FullQAPipeline(BaseReasoning): def stream( # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore ) -> Generator[Document, None, Document]: - import markdown - - docs = [] - doc_ids = [] if self.use_rewrite: message = self.rewrite_pipeline(question=message).text - for retriever in self.retrievers: - for doc in retriever(text=message): - if doc.doc_id not in doc_ids: - docs.append(doc) - doc_ids.append(doc.doc_id) - for doc in docs: - # TODO: a better approach to show the information - text = markdown.markdown( - doc.text, extensions=["markdown.extensions.tables"] - ) - yield Document( - content=( - "
" - f"{doc.metadata['file_name']}" - f"{text}" - "

" - ), - channel="info", - ) + docs, infos = self.retrieve(message) + for _ in infos: + yield _ evidence_mode, evidence = self.evidence_pipeline(docs).content answer = yield from self.answering_pipeline.stream( @@ -596,89 +578,21 @@ class FullQAPipeline(BaseReasoning): **kwargs, ) - # prepare citation - spans = defaultdict(list) - if answer.metadata["citation"] is not None: - for fact_with_evidence in answer.metadata["citation"].answer: - for quote in fact_with_evidence.substring_quote: - for doc in docs: - start_idx = doc.text.find(quote) - if start_idx == -1: - continue - - end_idx = start_idx + len(quote) - - current_idx = start_idx - if "|" not in doc.text[start_idx:end_idx]: - spans[doc.doc_id].append( - {"start": start_idx, "end": end_idx} - ) - else: - while doc.text[current_idx:end_idx].find("|") != -1: - match_idx = doc.text[current_idx:end_idx].find("|") - spans[doc.doc_id].append( - { - "start": current_idx, - "end": current_idx + match_idx, - } - ) - current_idx += match_idx + 2 - if current_idx > end_idx: - break - break - - id2docs = {doc.doc_id: doc for doc in docs} - lack_evidence = True - not_detected = set(id2docs.keys()) - set(spans.keys()) - yield Document(channel="info", content=None) - for id, ss in spans.items(): - if not ss: - not_detected.add(id) - continue - ss = sorted(ss, key=lambda x: x["start"]) - text = id2docs[id].text[: ss[0]["start"]] - for idx, span in enumerate(ss): - text += ( - "" + id2docs[id].text[span["start"] : span["end"]] + "" - ) - if idx < len(ss) - 1: - text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] - text += id2docs[id].text[ss[-1]["end"] :] - text_out = markdown.markdown( - text, extensions=["markdown.extensions.tables"] - ) - yield Document( - content=( - "
" - f"{id2docs[id].metadata['file_name']}" - f"{text_out}" - "

" - ), - channel="info", - ) - lack_evidence = False - - if lack_evidence: + # show the evidence + with_citation, without_citation = self.prepare_citations(answer, docs) + if not with_citation and not without_citation: yield Document(channel="info", content="No evidence found.\n") - - if not_detected: - yield Document( - channel="info", - content="Retrieved segments without matching evidence:\n", - ) - for id in list(not_detected): - text_out = markdown.markdown( - id2docs[id].text, extensions=["markdown.extensions.tables"] - ) + else: + yield Document(channel="info", content=None) + for _ in with_citation: + yield _ + if without_citation: yield Document( - content=( - "
" - f"{id2docs[id].metadata['file_name']}" - f"{text_out}" - "

" - ), channel="info", + content="Retrieved segments without matching evidence:\n", ) + for _ in without_citation: + yield _ return answer diff --git a/libs/ktem/ktem/utils/__init__.py b/libs/ktem/ktem/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py new file mode 100644 index 0000000..5890d33 --- /dev/null +++ b/libs/ktem/ktem/utils/render.py @@ -0,0 +1,21 @@ +import markdown + + +class Render: + """Default text rendering into HTML for the UI""" + + @staticmethod + def collapsible(header, content, open: bool = False) -> str: + """Render an HTML friendly collapsible section""" + o = " open" if open else "" + return f"{header}{content}
" + + @staticmethod + def table(text: str) -> str: + """Render table from markdown format into HTML""" + return markdown.markdown(text, extensions=["markdown.extensions.tables"]) + + @staticmethod + def highlight(text: str) -> str: + """Highlight text""" + return f"{text}"