From 43a18ba07096aee4136578d0564940f1631f472f Mon Sep 17 00:00:00 2001
From: ian_Cin
Date: Wed, 3 Apr 2024 15:37:55 +0700
Subject: [PATCH] Feat/regenerate answer (#7)

* Add regen button and rephrase the question on regen

* Stop appending regen messages to history, allow only one

* Add dynamic conversation state

* Allow reasoning pipeline to manipulate state

---------

Co-authored-by: albert
Co-authored-by: Duc Nguyen (john)
---
 libs/kotaemon/kotaemon/loaders/utils/adobe.py |  2 -
 libs/ktem/ktem/app.py                         |  1 +
 libs/ktem/ktem/pages/chat/__init__.py         | 89 ++++++++++++++++---
 libs/ktem/ktem/pages/chat/chat_panel.py       |  1 +
 libs/ktem/ktem/pages/chat/common.py           |  4 +
 libs/ktem/ktem/pages/chat/control.py          |  6 +-
 libs/ktem/ktem/pages/chat/report.py           |  2 +
 libs/ktem/ktem/reasoning/simple.py            | 70 +++++++++++++--
 8 files changed, 151 insertions(+), 24 deletions(-)
 create mode 100644 libs/ktem/ktem/pages/chat/common.py

diff --git a/libs/kotaemon/kotaemon/loaders/utils/adobe.py b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
index a780c45..f1adcd5 100644
--- a/libs/kotaemon/kotaemon/loaders/utils/adobe.py
+++ b/libs/kotaemon/kotaemon/loaders/utils/adobe.py
@@ -15,8 +15,6 @@ from decouple import config
 
 from kotaemon.loaders.utils.gpt4v import generate_gpt4v
 
-logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
-
 
 def request_adobe_service(file_path: str, output_path: str = "") -> str:
     """Main function to call the adobe service, and unzip the results.
diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py
index 0a39fa6..9bac904 100644
--- a/libs/ktem/ktem/app.py
+++ b/libs/ktem/ktem/app.py
@@ -17,6 +17,7 @@ class BaseApp:
 
     The main application contains app-level information:
         - setting state
+        - dynamic conversation state
        - user id
 
     Also contains registering methods for:
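The dynamic conversation state added here is carried through the UI as a Gradio
State component: state only survives an event if it is listed in both the inputs
and the outputs of that event, which is why `chat_state` is threaded through the
submit, regen, and update_data_source wiring in the next file. A minimal
standalone sketch of that round-trip pattern (hypothetical handler names, not
ktem code):

    import gradio as gr

    STATE = {"app": {"regen": False}}

    def handler(message, state):
        # read the per-session state, update a copy, return it as an output
        state = {**state, "app": {**state["app"], "regen": False}}
        return f"echo: {message}", state

    with gr.Blocks() as demo:
        chat_state = gr.State(STATE)  # Gradio deep-copies the default per session
        box = gr.Textbox()
        out = gr.Textbox()
        box.submit(handler, inputs=[box, chat_state], outputs=[out, chat_state])

    demo.launch()

Because Gradio deep-copies a State default per browser session, the shared
module-level STATE dict is safe to use as the initial value.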
diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py
index 6648c2f..d2bba87 100644
--- a/libs/ktem/ktem/pages/chat/__init__.py
+++ b/libs/ktem/ktem/pages/chat/__init__.py
@@ -9,6 +9,7 @@ from ktem.db.models import Conversation, engine
 from sqlmodel import Session, select
 
 from .chat_panel import ChatPanel
+from .common import STATE
 from .control import ConversationControl
 from .report import ReportIssue
 
@@ -21,6 +22,7 @@ class ChatPage(BasePage):
 
     def on_building_ui(self):
         with gr.Row():
+            self.chat_state = gr.State(STATE)
             with gr.Column(scale=1):
                 self.chat_control = ConversationControl(self._app)
 
@@ -62,12 +64,13 @@ class ChatPage(BasePage):
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self._app.settings_state,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=[
-                self.chat_panel.text_input,
                 self.chat_panel.chatbot,
                 self.info_panel,
+                self.chat_state,
             ],
             show_progress="minimal",
         ).then(
@@ -75,6 +78,33 @@ class ChatPage(BasePage):
             inputs=[
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
+                self.chat_state,
+            ]
+            + self._indices_input,
+            outputs=None,
+        )
+
+        self.chat_panel.regen_btn.click(
+            fn=self.regen_fn,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_panel.chatbot,
+                self._app.settings_state,
+                self.chat_state,
+            ]
+            + self._indices_input,
+            outputs=[
+                self.chat_panel.chatbot,
+                self.info_panel,
+                self.chat_state,
+            ],
+            show_progress="minimal",
+        ).then(
+            fn=self.update_data_source,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_panel.chatbot,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=None,
@@ -94,6 +124,7 @@ class ChatPage(BasePage):
             self.chat_control.conversation,
             self.chat_control.conversation_rn,
             self.chat_panel.chatbot,
+            self.chat_state,
         ]
         + self._indices_input,
         show_progress="hidden",
@@ -109,12 +140,13 @@ class ChatPage(BasePage):
                 self.chat_panel.chatbot,
                 self._app.settings_state,
                 self._app.user_id,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=None,
         )
 
-    def update_data_source(self, convo_id, messages, *selecteds):
+    def update_data_source(self, convo_id, messages, state, *selecteds):
         """Update the data source"""
         if not convo_id:
             gr.Warning("No conversation selected")
@@ -133,6 +165,7 @@ class ChatPage(BasePage):
             result.data_source = {
                 "selected": selecteds_,
                 "messages": messages,
+                "state": state,
                 "likes": deepcopy(data_source.get("likes", [])),
             }
             session.add(result)
@@ -152,17 +185,22 @@ class ChatPage(BasePage):
             session.add(result)
             session.commit()
 
-    def create_pipeline(self, settings: dict, *selecteds):
+    def create_pipeline(self, settings: dict, state: dict, *selecteds):
         """Create the pipeline from settings
 
         Args:
             settings: the settings of the app
+            state: the conversation state, including the regen flag
             selected: the list of file ids that will be served as context.
                 If None, then consider using all files
 
         Returns:
-            the pipeline objects
+            the pipeline object and the reasoning state
         """
+        reasoning_mode = settings["reasoning.use"]
+        reasoning_cls = reasonings[reasoning_mode]
+        reasoning_id = reasoning_cls.get_info()["id"]
+
         # get retrievers
         retrievers = []
         for index in self._app.index_manager.indices:
@@ -172,13 +210,17 @@ class ChatPage(BasePage):
             iretrievers = index.get_retriever_pipelines(settings, index_selected)
             retrievers += iretrievers
 
-        reasoning_mode = settings["reasoning.use"]
-        reasoning_cls = reasonings[reasoning_mode]
-        pipeline = reasoning_cls.get_pipeline(settings, retrievers)
+        # prepare states
+        reasoning_state = {
+            "app": deepcopy(state["app"]),
+            "pipeline": deepcopy(state.get(reasoning_id, {})),
+        }
 
-        return pipeline
+        pipeline = reasoning_cls.get_pipeline(settings, reasoning_state, retrievers)
 
-    async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
+        return pipeline, reasoning_state
+
+    async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
         """Chat function"""
         chat_input = chat_history[-1][0]
         chat_history = chat_history[:-1]
@@ -186,7 +228,7 @@ class ChatPage(BasePage):
         queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
 
         # construct the pipeline
-        pipeline = self.create_pipeline(settings, *selecteds)
+        pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
         pipeline.set_output_queue(queue)
 
         asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
@@ -198,7 +240,8 @@ class ChatPage(BasePage):
             try:
                 response = queue.get_nowait()
             except Exception:
-                yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
+                state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+                yield chat_history + [(chat_input, text or "Thinking ...")], refs, state
                 continue
 
             if response is None:
@@ -208,6 +251,7 @@ class ChatPage(BasePage):
 
             if "output" in response:
                 text += response["output"]
+
             if "evidence" in response:
                 if response["evidence"] is None:
                     refs = ""
@@ -218,4 +262,25 @@ class ChatPage(BasePage):
                 print(f"Len refs: {len(refs)}")
                 len_ref = len(refs)
 
-        yield "", chat_history + [(chat_input, text)], refs
+        state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+        yield chat_history + [(chat_input, text)], refs, state
+
+    async def regen_fn(
+        self, conversation_id, chat_history, settings, state, *selecteds
+    ):
+        """Regen function"""
+        if not chat_history:
+            gr.Warning("Empty chat")
+            yield chat_history, "", state
+            return
+
+        state["app"]["regen"] = True
+        async for chat, refs, state in self.chat_fn(
+            conversation_id, chat_history, settings, state, *selecteds
+        ):
+            new_state = deepcopy(state)
+            new_state["app"]["regen"] = False
+            yield chat, refs, new_state
+        else:
+            state["app"]["regen"] = False
+            yield chat_history, "", state
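regen_fn above is a thin wrapper over chat_fn: it flips the regen flag in the
app state, re-runs the normal chat handler on the unchanged history (chat_fn
pops the last user message itself), and clears the flag on every state it
re-emits, so the turn after a regen behaves normally again. Note that the else
attached to the async for runs whenever the loop finishes without break, so its
reset-and-restore yield always follows the streamed updates. A simplified
sketch of the delegation, with stand-in names and without the Gradio plumbing:

    import asyncio
    from copy import deepcopy

    async def chat_fn(history, state):
        # stand-in for the real streaming handler; the real chat_fn reads
        # state["app"]["regen"] while building the reasoning pipeline
        yield history[:-1] + [(history[-1][0], "regenerated answer")], state

    async def regen_fn(history, state):
        state["app"]["regen"] = True
        async for chat, state in chat_fn(history, state):
            new_state = deepcopy(state)
            new_state["app"]["regen"] = False  # one-shot: cleared on everything emitted
            yield chat, new_state

    async def main():
        history = [("hi", "old answer")]
        async for chat, state in regen_fn(history, {"app": {"regen": False}}):
            print(chat, state)

    asyncio.run(main())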
diff --git a/libs/ktem/ktem/pages/chat/chat_panel.py b/libs/ktem/ktem/pages/chat/chat_panel.py
index f4cfc5b..55b9258 100644
--- a/libs/ktem/ktem/pages/chat/chat_panel.py
+++ b/libs/ktem/ktem/pages/chat/chat_panel.py
@@ -19,6 +19,7 @@ class ChatPanel(BasePage):
             placeholder="Chat input", scale=15, container=False
         )
         self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
+        self.regen_btn = gr.Button(value="Regen", scale=1, min_width=10)
 
     def submit_msg(self, chat_input, chat_history):
         """Submit a message to the chatbot"""
diff --git a/libs/ktem/ktem/pages/chat/common.py b/libs/ktem/ktem/pages/chat/common.py
new file mode 100644
index 0000000..a2fc0dc
--- /dev/null
+++ b/libs/ktem/ktem/pages/chat/common.py
@@ -0,0 +1,4 @@
+DEFAULT_APPLICATION_STATE = {"regen": False}
+STATE = {
+    "app": DEFAULT_APPLICATION_STATE,
+}
diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py
index a0b2561..e714112 100644
--- a/libs/ktem/ktem/pages/chat/control.py
+++ b/libs/ktem/ktem/pages/chat/control.py
@@ -5,6 +5,8 @@ from ktem.app import BasePage
 from ktem.db.models import Conversation, engine
 from sqlmodel import Session, select
 
+from .common import STATE
+
 logger = logging.getLogger(__name__)
 
 
@@ -159,12 +161,14 @@ class ConversationControl(BasePage):
             name = result.name
             selected = result.data_source.get("selected", {})
             chats = result.data_source.get("messages", [])
+            state = result.data_source.get("state", STATE)
         except Exception as e:
             logger.warning(e)
             id_ = ""
             name = ""
             selected = {}
             chats = []
+            state = STATE
 
         indices = []
         for index in self._app.index_manager.indices:
@@ -173,7 +177,7 @@ class ConversationControl(BasePage):
                 continue
             indices.append(selected.get(str(index.id), []))
 
-        return id_, id_, name, chats, *indices
+        return id_, id_, name, chats, state, *indices
 
     def rename_conv(self, conversation_id, new_name, user_id):
         """Rename the conversation"""
diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py
index 46d9e3c..25d83f8 100644
--- a/libs/ktem/ktem/pages/chat/report.py
+++ b/libs/ktem/ktem/pages/chat/report.py
@@ -48,6 +48,7 @@ class ReportIssue(BasePage):
         chat_history: list,
         settings: dict,
         user_id: Optional[int],
+        chat_state: dict,
         *selecteds
     ):
         selecteds_ = {}
@@ -65,6 +66,7 @@ class ReportIssue(BasePage):
             chat={
                 "conv_id": conv_id,
                 "chat_history": chat_history,
+                "chat_state": chat_state,
                 "selecteds": selecteds_,
             },
             settings=settings,
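The UI changes above make the state part of the saved conversation:
update_data_source writes it into the conversation's data_source JSON next to
the messages, control.py reads it back with the module-level STATE as the
fallback for conversations saved before this change, and report.py attaches it
to filed issues. A sketch of that save/load contract (assumed flat schema, not
the actual sqlmodel code):

    from copy import deepcopy

    DEFAULT_STATE = {"app": {"regen": False}}

    def save_data_source(data_source: dict, messages: list, state: dict) -> dict:
        # the state travels next to the messages in the conversation blob
        return {**data_source, "messages": messages, "state": state}

    def load_data_source(data_source: dict):
        # deepcopy the fallback so old conversations do not share and
        # mutate the module-level default
        state = deepcopy(data_source.get("state", DEFAULT_STATE))
        return data_source.get("messages", []), state

The deepcopy on load is a precaution the patch itself does not take: select_conv
returns the shared STATE object directly when no state was saved, so mutating it
there would leak into every other conversation that hits the same fallback.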
" + "Give answer in {lang}\n" + "Original question: {question}\n" + "Rephrased question: " +) + class AnswerWithContextPipeline(BaseComponent): """Answer the question based on the evidence @@ -287,15 +295,48 @@ class AnswerWithContextPipeline(BaseComponent): return answer - def extract_evidence_images(self, evidence: str): - """Util function to extract and isolate images from context/evidence""" - image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" - matches = re.findall(image_pattern, evidence) - context = re.sub(image_pattern, "", evidence) - return context, matches + +def extract_evidence_images(self, evidence: str): + """Util function to extract and isolate images from context/evidence""" + image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" + matches = re.findall(image_pattern, evidence) + context = re.sub(image_pattern, "", evidence) + return context, matches -class FullQAPipeline(BaseReasoning): +class RewriteQuestionPipeline(BaseComponent): + """Rewrite user question + + Args: + llm: the language model to rewrite question + rewrite_template: the prompt template for llm to paraphrase a text input + lang: the language of the answer. Currently support English and Japanese + """ + + llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost()) + rewrite_template: str = DEFAULT_REWRITE_PROMPT + + lang: str = "English" + + async def run(self, question: str) -> Document: # type: ignore + prompt_template = PromptTemplate(self.rewrite_template) + prompt = prompt_template.populate(question=question, lang=self.lang) + messages = [ + SystemMessage(content="You are a helpful assistant"), + HumanMessage(content=prompt), + ] + output = "" + for text in self.llm(messages): + if "content" in text: + output += text[1] + self.report_output({"chat_input": text[1]}) + break + await asyncio.sleep(0) + + return Document(text=output) + + +class FullQAPipeline(BaseComponent): """Question answering pipeline. Handle from question to answer""" class Config: @@ -305,12 +346,18 @@ class FullQAPipeline(BaseReasoning): evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx() answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx() + rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx() + use_rewrite: bool = False async def run( # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore ) -> Document: # type: ignore docs = [] doc_ids = [] + if self.use_rewrite: + rewrite = await self.rewrite_pipeline(question=message) + message = rewrite.text + for retriever in self.retrievers: for doc in retriever(text=message): if doc.doc_id not in doc_ids: @@ -402,7 +449,7 @@ class FullQAPipeline(BaseReasoning): return answer @classmethod - def get_pipeline(cls, settings, retrievers): + def get_pipeline(cls, settings, states, retrievers): """Get the reasoning pipeline Args: @@ -430,6 +477,11 @@ class FullQAPipeline(BaseReasoning): pipeline.answering_pipeline.qa_template = settings[ f"reasoning.options.{_id}.qa_prompt" ] + pipeline.use_rewrite = states.get("app", {}).get("regen", False) + pipeline.rewrite_pipeline.llm = llms.get_lowest_cost() + pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get( + settings["reasoning.lang"], "English" + ) return pipeline @classmethod