diff --git a/libs/kotaemon/kotaemon/indices/ingests/files.py b/libs/kotaemon/kotaemon/indices/ingests/files.py
index 75f944e..0f80cb2 100644
--- a/libs/kotaemon/kotaemon/indices/ingests/files.py
+++ b/libs/kotaemon/kotaemon/indices/ingests/files.py
@@ -47,6 +47,8 @@ class DocumentIngestor(BaseComponent):
     text_splitter: BaseSplitter = TokenSplitter.withx(
         chunk_size=1024,
         chunk_overlap=256,
+        separator="\n\n",
+        backup_separators=["\n", ".", " ", "\u200B"],
     )

     override_file_extractors: dict[str, Type[BaseReader]] = {}
diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py
index d4881d8..1643c50 100644
--- a/libs/ktem/ktem/reasoning/simple.py
+++ b/libs/ktem/ktem/reasoning/simple.py
@@ -11,6 +11,7 @@ from ktem.llms.manager import llms
 from ktem.utils.render import Render

 from kotaemon.base import (
+    AIMessage,
     BaseComponent,
     Document,
     HumanMessage,
@@ -205,6 +206,7 @@ class AnswerWithContextPipeline(BaseComponent):
     enable_citation: bool = False
     system_prompt: str = ""
     lang: str = "English"  # support English and Japanese
+    n_last_interactions: int = 5

     def get_prompt(self, question, evidence, evidence_mode: int):
         """Prepare the prompt and other information for LLM"""
@@ -244,6 +246,7 @@ class AnswerWithContextPipeline(BaseComponent):
     def invoke(
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Document:
+        history = kwargs.get("history", [])
         prompt, images = self.get_prompt(question, evidence, evidence_mode)

         output = ""
@@ -253,6 +256,9 @@ class AnswerWithContextPipeline(BaseComponent):
             messages = []
             if self.system_prompt:
                 messages.append(SystemMessage(content=self.system_prompt))
+            for human, ai in history[-self.n_last_interactions :]:
+                messages.append(HumanMessage(content=human))
+                messages.append(AIMessage(content=ai))
             messages.append(HumanMessage(content=prompt))
             output = self.llm(messages).text

@@ -292,6 +298,7 @@ class AnswerWithContextPipeline(BaseComponent):
                 (determined by retrieval pipeline)
             evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
         """
+        history = kwargs.get("history", [])
         prompt, images = self.get_prompt(question, evidence, evidence_mode)

         citation_task = None
@@ -311,6 +318,9 @@ class AnswerWithContextPipeline(BaseComponent):
             messages = []
             if self.system_prompt:
                 messages.append(SystemMessage(content=self.system_prompt))
+            for human, ai in history[-self.n_last_interactions :]:
+                messages.append(HumanMessage(content=human))
+                messages.append(AIMessage(content=ai))
             messages.append(HumanMessage(content=prompt))

             try:
@@ -339,6 +349,7 @@ class AnswerWithContextPipeline(BaseComponent):
     def stream(  # type: ignore
         self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
     ) -> Generator[Document, None, Document]:
+        history = kwargs.get("history", [])
         prompt, images = self.get_prompt(question, evidence, evidence_mode)

         output = ""
@@ -350,6 +361,9 @@ class AnswerWithContextPipeline(BaseComponent):
             messages = []
             if self.system_prompt:
                 messages.append(SystemMessage(content=self.system_prompt))
+            for human, ai in history[-self.n_last_interactions :]:
+                messages.append(HumanMessage(content=human))
+                messages.append(AIMessage(content=ai))
             messages.append(HumanMessage(content=prompt))

             try:
@@ -406,6 +420,50 @@ class RewriteQuestionPipeline(BaseComponent):
         return self.llm(messages)


+class AddQueryContextPipeline(BaseComponent):
+
+    n_last_interactions: int = 5
+    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
+
+    def run(self, question: str, history: list) -> Document:
+        messages = [
+            SystemMessage(
+                content="Below is a history of the conversation so far, and a new "
+                "question asked by the user that needs to be answered by searching "
+                "in a knowledge base.\nYou have access to a Search index "
+                "with 100's of documents.\nGenerate a search query based on the "
+                "conversation and the new question.\nDo not include cited source "
+                "filenames and document names e.g info.txt or doc.pdf in the search "
+                "query terms.\nDo not include any text inside [] or <<>> in the "
+                "search query terms.\nDo not include any special characters like "
+                "'+'.\nIf the question is not in English, rewrite the query in "
+                "the language used in the question.\n If the question contains enough "
+                "information, return just the number 1\n If it's unnecessary to do "
+                "the searching, return just the number 0."
+            ),
+            HumanMessage(content="How did crypto do last year?"),
+            AIMessage(
+                content="Summarize Cryptocurrency Market Dynamics from last year"
+            ),
+            HumanMessage(content="What are my health plans?"),
+            AIMessage(content="Show available health plans"),
+        ]
+        for human, ai in history[-self.n_last_interactions :]:
+            messages.append(HumanMessage(content=human))
+            messages.append(AIMessage(content=ai))
+
+        messages.append(HumanMessage(content=f"Generate search query for: {question}"))
+
+        resp = self.llm(messages).text
+        if resp == "0":
+            return Document(content="")
+
+        if resp == "1":
+            return Document(content=question)
+
+        return Document(content=resp)
+
+
 class FullQAPipeline(BaseReasoning):
     """Question answering pipeline. Handle from question to answer"""

@@ -417,13 +475,29 @@ class FullQAPipeline(BaseReasoning):
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
     rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
+    add_query_context: AddQueryContextPipeline = AddQueryContextPipeline.withx()
+
+    trigger_context: int = 150
     use_rewrite: bool = False

-    def retrieve(self, message: str) -> tuple[list[RetrievedDocument], list[Document]]:
+    def retrieve(
+        self, message: str, history: list
+    ) -> tuple[list[RetrievedDocument], list[Document]]:
         """Retrieve the documents based on the message"""
+        if len(message) < self.trigger_context:
+            # prefer adding context for short user questions; long questions are
+            # likely to contain enough information already, and rewriting them
+            # risks pushing an already long message past what the model can
+            # handle
+            query = self.add_query_context(message, history).content
+        else:
+            query = message
+        print(f"Rewritten query: {query}")
+        if not query:
+            return [], []
+
         docs, doc_ids = [], []
         for retriever in self.retrievers:
-            for doc in retriever(text=message):
+            for doc in retriever(text=query):
                 if doc.doc_id not in doc_ids:
                     docs.append(doc)
                     doc_ids.append(doc.doc_id)
@@ -522,7 +596,7 @@ class FullQAPipeline(BaseReasoning):
             rewrite = await self.rewrite_pipeline(question=message)
             message = rewrite.text

-        docs, infos = self.retrieve(message)
+        docs, infos = self.retrieve(message, history)
         for _ in infos:
             self.report_output(_)
             await asyncio.sleep(0.1)
@@ -564,7 +638,8 @@ class FullQAPipeline(BaseReasoning):
         if self.use_rewrite:
             message = self.rewrite_pipeline(question=message).text

-        docs, infos = self.retrieve(message)
+        # rewrite short queries with conversation context before retrieving
+        docs, infos = self.retrieve(message, history)
         for _ in infos:
             yield _

@@ -604,24 +679,27 @@ class FullQAPipeline(BaseReasoning):
             settings: the settings for the pipeline
             retrievers: the retrievers to use
         """
-        _id = cls.get_info()["id"]
-
+        prefix = f"reasoning.options.{cls.get_info()['id']}"
         pipeline = FullQAPipeline(retrievers=retrievers)

-        pipeline.answering_pipeline.llm = llms.get_default()
-        pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()
-        pipeline.answering_pipeline.enable_citation = settings[
-            f"reasoning.options.{_id}.highlight_citation"
-        ]
-        pipeline.answering_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
+        # answering pipeline configuration
+        answer_pipeline = pipeline.answering_pipeline
+        answer_pipeline.llm = llms.get_default()
+        answer_pipeline.citation_pipeline.llm = llms.get_default()
+        answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
+        answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
+        answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
+        answer_pipeline.qa_template = settings[f"{prefix}.qa_prompt"]
+        answer_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
             settings["reasoning.lang"], "English"
         )
-        pipeline.answering_pipeline.system_prompt = settings[
-            f"reasoning.options.{_id}.system_prompt"
-        ]
-        pipeline.answering_pipeline.qa_template = settings[
-            f"reasoning.options.{_id}.qa_prompt"
+
+        pipeline.add_query_context.llm = llms.get_default()
+        pipeline.add_query_context.n_last_interactions = settings[
+            f"{prefix}.n_last_interactions"
         ]
+
+        pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
         pipeline.use_rewrite = states.get("app", {}).get("regen", False)
         pipeline.rewrite_pipeline.llm = llms.get_default()
         pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
@@ -645,6 +723,21 @@ class FullQAPipeline(BaseReasoning):
                 "name": "QA Prompt (contains {context}, {question}, {lang})",
                 "value": DEFAULT_QA_TEXT_PROMPT,
             },
+            "n_last_interactions": {
+                "name": "Number of interactions to include",
+                "value": 5,
+                "component": "number",
+                "info": "The maximum number of chat interactions to include in the LLM prompt",
+            },
+            "trigger_context": {
+                "name": "Maximum message length for context rewriting",
+                "value": 150,
+                "component": "number",
+                "info": (
+                    "The maximum message length that triggers context rewriting. "
+                    "Messages longer than this are used as-is."
+                ),
+            },
         }

     @classmethod
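
The heart of the patch is the routing in FullQAPipeline.retrieve: short messages are rewritten against recent chat history by AddQueryContextPipeline, and the model's reply of "0" skips retrieval entirely, "1" keeps the question verbatim, while any other text becomes the standalone search query. Below is a minimal, self-contained sketch of that flow; build_query, fake_llm, and the module-level constants are hypothetical stand-ins for the kotaemon components, not part of the patch.

    from typing import Callable

    TRIGGER_CONTEXT = 150     # mirrors the new `trigger_context` setting
    N_LAST_INTERACTIONS = 5   # mirrors the new `n_last_interactions` setting


    def build_query(
        message: str,
        history: list[tuple[str, str]],
        llm: Callable[[list[tuple[str, str]], str], str],
    ) -> str:
        """Return the query to send to the retrievers, or "" to skip retrieval."""
        if len(message) >= TRIGGER_CONTEXT:
            # long messages are assumed to carry enough context already
            return message

        # only the last N (human, ai) turns are folded in, as in the patch
        recent_turns = history[-N_LAST_INTERACTIONS:]
        resp = llm(recent_turns, f"Generate search query for: {message}")

        if resp == "0":  # the model judged that no search is needed
            return ""
        if resp == "1":  # the question already contains enough information
            return message
        return resp      # otherwise, use the rewritten standalone query


    # toy LLM stand-in that always produces a standalone rewrite
    def fake_llm(turns: list[tuple[str, str]], prompt: str) -> str:
        return "How did cryptocurrency perform this year?"


    history = [("How did crypto do last year?", "Crypto markets were volatile.")]
    print(build_query("and this year?", history, fake_llm))

Gating the rewrite on message length spends one extra LLM call per short message so that follow-ups like "and this year?" hit the index with usable context, while long, self-contained questions go straight to the retrievers.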