diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py
index 4e6f7b8..9acd39f 100644
--- a/libs/kotaemon/kotaemon/base/component.py
+++ b/libs/kotaemon/kotaemon/base/component.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Iterator, Optional
+from typing import AsyncGenerator, Iterator, Optional

 from theflow import Function, Node, Param, lazy

@@ -43,6 +43,18 @@ class BaseComponent(Function):
         if self._queue is not None:
             self._queue.put_nowait(output)

+    def invoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    async def ainvoke(self, *args, **kwargs) -> Document | list[Document] | None:
+        ...
+
+    def stream(self, *args, **kwargs) -> Iterator[Document] | None:
+        ...
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
+        ...
+
     @abstractmethod
     def run(
         self, *args, **kwargs
diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py
index 4c1281a..4fe8600 100644
--- a/libs/kotaemon/kotaemon/indices/qa/citation.py
+++ b/libs/kotaemon/kotaemon/indices/qa/citation.py
@@ -65,6 +65,9 @@ class CitationPipeline(BaseComponent):
     llm: BaseLLM

     def run(self, context: str, question: str):
+        return self.invoke(context, question)
+
+    def prepare_llm(self, context: str, question: str):
         schema = QuestionAnswer.schema()
         function = {
             "name": schema["title"],
@@ -92,8 +95,37 @@ class CitationPipeline(BaseComponent):
                 )
             ),
         ]
+        return messages, llm_kwargs
+
+    def invoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+
+        try:
+            print("CitationPipeline: invoking LLM")
+            llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish invoking LLM")
+        except Exception as e:
+            print(e)
+            return None
+
+        function_output = llm_output.messages[0].additional_kwargs["function_call"][
+            "arguments"
+        ]
+        output = QuestionAnswer.parse_raw(function_output)
+
+        return output
+
+    async def ainvoke(self, context: str, question: str):
+        messages, llm_kwargs = self.prepare_llm(context, question)
+
+        try:
+            print("CitationPipeline: async invoking LLM")
+            llm_output = await self.get_from_path("llm").ainvoke(messages, **llm_kwargs)
+            print("CitationPipeline: finish async invoking LLM")
+        except Exception as e:
+            print(e)
+            return None

-        llm_output = self.llm(messages, **llm_kwargs)
         function_output = llm_output.messages[0].additional_kwargs["function_call"][
             "arguments"
         ]
diff --git a/libs/kotaemon/kotaemon/llms/base.py b/libs/kotaemon/kotaemon/llms/base.py
index ff315ea..6ef7afc 100644
--- a/libs/kotaemon/kotaemon/llms/base.py
+++ b/libs/kotaemon/kotaemon/llms/base.py
@@ -1,8 +1,22 @@
+from typing import AsyncGenerator, Iterator
+
 from langchain_core.language_models.base import BaseLanguageModel

-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface


 class BaseLLM(BaseComponent):
     def to_langchain_format(self) -> BaseLanguageModel:
         raise NotImplementedError
+
+    def invoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    async def ainvoke(self, *args, **kwargs) -> LLMInterface:
+        raise NotImplementedError
+
+    def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
+        raise NotImplementedError
+
+    async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
+        raise NotImplementedError
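Together, these hunks give every BaseComponent the same four optional entry points — invoke/ainvoke for one-shot calls, stream/astream for incremental output — and CitationPipeline now implements the first two on top of a shared prepare_llm step. A minimal usage sketch; the `llm` wiring is an assumption (any configured BaseLLM that supports OpenAI-style function calling), not something this diff shows:

    import asyncio

    from kotaemon.indices.qa.citation import CitationPipeline
    from kotaemon.llms import BaseLLM


    def demo(llm: BaseLLM) -> None:
        # `llm` is a placeholder for an already-configured BaseLLM instance.
        pipeline = CitationPipeline(llm=llm)

        # Synchronous path: run() now simply delegates to invoke().
        qa = pipeline.run(context="...retrieved evidence...", question="...")

        # Asynchronous path: same arguments, awaitable; returns None if the LLM call fails.
        qa = asyncio.run(pipeline.ainvoke(context="...retrieved evidence...", question="..."))
        print(qa)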
diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
index c5c2469..14064ba 100644
--- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
+++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+from typing import AsyncGenerator, Iterator

 from kotaemon.base import BaseMessage, HumanMessage, LLMInterface

@@ -10,6 +11,8 @@ logger = logging.getLogger(__name__)


 class LCChatMixin:
+    """Mixin for langchain based chat models"""
+
     def _get_lc_class(self):
         raise NotImplementedError(
             "Please return the relevant Langchain class in in _get_lc_class"
@@ -30,18 +33,7 @@ class LCChatMixin:
             return self.stream(messages, **kwargs)  # type: ignore
         return self.invoke(messages, **kwargs)

-    def invoke(
-        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
-    ) -> LLMInterface:
-        """Generate response from messages
-
-        Args:
-            messages: history of messages to generate response from
-            **kwargs: additional arguments to pass to the langchain chat model
-
-        Returns:
-            LLMInterface: generated response
-        """
+    def prepare_message(self, messages: str | BaseMessage | list[BaseMessage]):
         input_: list[BaseMessage] = []

         if isinstance(messages, str):
@@ -51,7 +43,9 @@ class LCChatMixin:
         else:
             input_ = messages

-        pred = self._obj.generate(messages=[input_], **kwargs)
+        return input_
+
+    def prepare_response(self, pred):
         all_text = [each.text for each in pred.generations[0]]
         all_messages = [each.message for each in pred.generations[0]]

@@ -76,10 +70,41 @@ class LCChatMixin:
             logits=[],
         )

-    def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        """Generate response from messages
+
+        Args:
+            messages: history of messages to generate response from
+            **kwargs: additional arguments to pass to the langchain chat model
+
+        Returns:
+            LLMInterface: generated response
+        """
+        input_ = self.prepare_message(messages)
+        pred = self._obj.generate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    async def ainvoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
+        input_ = self.prepare_message(messages)
+        pred = await self._obj.agenerate(messages=[input_], **kwargs)
+        return self.prepare_response(pred)
+
+    def stream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> Iterator[LLMInterface]:
         for response in self._obj.stream(input=messages, **kwargs):
             yield LLMInterface(content=response.content)

+    async def astream(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> AsyncGenerator[LLMInterface, None]:
+        async for response in self._obj.astream(input=messages, **kwargs):
+            yield LLMInterface(content=response.content)
+
     def to_langchain_format(self):
         return self._obj

@@ -140,7 +165,7 @@ class LCChatMixin:
             raise ValueError(f"Invalid param {path}")


-class AzureChatOpenAI(LCChatMixin, ChatLLM):
+class AzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
     def __init__(
         self,
         azure_endpoint: str | None = None,
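Because invoke, ainvoke, stream, and astream now all route through the shared prepare_message/prepare_response helpers, the sync and async paths cannot drift apart. A usage sketch of the async side; note that every constructor argument except azure_endpoint (the only one visible in this diff) is an assumption based on typical Azure OpenAI settings:

    import asyncio

    from kotaemon.llms import AzureChatOpenAI

    llm = AzureChatOpenAI(
        azure_endpoint="https://<resource>.openai.azure.com/",  # placeholder
        openai_api_key="<key>",                                 # assumed kwarg, placeholder value
        openai_api_version="2023-05-15",                        # assumed kwarg, placeholder value
        deployment_name="<deployment>",                         # assumed kwarg, placeholder value
    )


    async def main() -> None:
        # One-shot async call: returns a single LLMInterface.
        reply = await llm.ainvoke("Summarize streaming in one sentence.")
        print(reply.text)

        # Incremental async call: yields LLMInterface chunks as tokens arrive.
        async for chunk in llm.astream("Summarize streaming in one sentence."):
            print(chunk.content, end="", flush=True)


    asyncio.run(main())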
response["evidence"] if len(refs) > len_ref: print(f"Len refs: {len(refs)}") diff --git a/libs/ktem/ktem/reasoning/base.py b/libs/ktem/ktem/reasoning/base.py index c122dfa..80cf016 100644 --- a/libs/ktem/ktem/reasoning/base.py +++ b/libs/ktem/ktem/reasoning/base.py @@ -1,5 +1,49 @@ +from typing import Optional + from kotaemon.base import BaseComponent class BaseReasoning(BaseComponent): - retrievers: list = [] + """The reasoning pipeline that handles each of the user chat messages + + This reasoning pipeline has access to: + - the retrievers + - the user settings + - the message + - the conversation id + - the message history + """ + + @classmethod + def get_info(cls) -> dict: + """Get the pipeline information for the app to organize and display + + Returns: + a dictionary that contains the following keys: + - "id": the unique id of the pipeline + - "name": the human-friendly name of the pipeline + - "description": the overview short description of the pipeline, for + user to grasp what does the pipeline do + """ + raise NotImplementedError + + @classmethod + def get_user_settings(cls) -> dict: + """Get the default user settings for this pipeline""" + return {} + + @classmethod + def get_pipeline( + cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None + ) -> "BaseReasoning": + """Get the reasoning pipeline for the app to execute + + Args: + user_setting: user settings + retrievers (list): List of retrievers + """ + return cls() + + def run(self, message: str, conv_id: str, history: list, **kwargs): # type: ignore + """Execute the reasoning pipeline""" + raise NotImplementedError diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index c8653a7..acd768f 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -200,22 +200,24 @@ class AnswerWithContextPipeline(BaseComponent): lang=self.lang, ) + citation_task = asyncio.create_task( + self.citation_pipeline.ainvoke(context=evidence, question=question) + ) + print("Citation task created") + messages = [] if self.system_prompt: messages.append(SystemMessage(content=self.system_prompt)) messages.append(HumanMessage(content=prompt)) output = "" - for text in self.llm(messages): + for text in self.llm.stream(messages): output += text.text self.report_output({"output": text.text}) await asyncio.sleep(0) - try: - citation = self.citation_pipeline(context=evidence, question=question) - except Exception as e: - print(e) - citation = None - + # retrieve the citation + print("Waiting for citation task") + citation = await citation_task answer = Document(text=output, metadata={"citation": citation}) return answer @@ -242,6 +244,19 @@ class FullQAPipeline(BaseReasoning): if doc.doc_id not in doc_ids: docs.append(doc) doc_ids.append(doc.doc_id) + for doc in docs: + self.report_output( + { + "evidence": ( + "
" + f"{doc.metadata['file_name']}" + f"{doc.text}" + "

" + ) + } + ) + await asyncio.sleep(0.1) + evidence_mode, evidence = self.evidence_pipeline(docs).content answer = await self.answering_pipeline( question=message, evidence=evidence, evidence_mode=evidence_mode @@ -266,6 +281,7 @@ class FullQAPipeline(BaseReasoning): id2docs = {doc.doc_id: doc for doc in docs} lack_evidence = True not_detected = set(id2docs.keys()) - set(spans.keys()) + self.report_output({"evidence": None}) for id, ss in spans.items(): if not ss: not_detected.add(id) @@ -282,7 +298,7 @@ class FullQAPipeline(BaseReasoning): self.report_output( { "evidence": ( - "
" + "
" f"{id2docs[id].metadata['file_name']}" f"{text}" "

"