From 49a083fd9f1bd93748b12fc2f20f240a5fa39f78 Mon Sep 17 00:00:00 2001 From: KennyWu Date: Thu, 10 Oct 2024 12:02:04 +0800 Subject: [PATCH] feat: tweak the 'Chat suggestion' feature to tie it to conversations (#341) #none Signed-off-by: Kennywu --- flowsettings.py | 1 + libs/ktem/ktem/db/base_models.py | 2 +- libs/ktem/ktem/pages/chat/__init__.py | 205 +++++++++++++++++- libs/ktem/ktem/pages/chat/control.py | 50 +++++ .../suggest_followup_chat.py | 45 ++++ 5 files changed, 295 insertions(+), 8 deletions(-) create mode 100644 libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py diff --git a/flowsettings.py b/flowsettings.py index 28ad9d2..ecc7580 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -63,6 +63,7 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface") KH_DOC_DIR = this_dir / "docs" KH_MODE = "dev" +KH_FEATURE_CHAT_SUGGESTION = config("KH_FEATURE_CHAT_SUGGESTION", default=False) KH_FEATURE_USER_MANAGEMENT = config( "KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool ) diff --git a/libs/ktem/ktem/db/base_models.py b/libs/ktem/ktem/db/base_models.py index 1379caf..7d8b3e5 100644 --- a/libs/ktem/ktem/db/base_models.py +++ b/libs/ktem/ktem/db/base_models.py @@ -34,7 +34,7 @@ class BaseConversation(SQLModel): is_public: bool = Field(default=False) - # contains messages + current files + # contains messages + current files + chat_suggestions data_source: dict = Field(default={}, sa_column=Column(JSON)) date_created: datetime.datetime = Field(default_factory=datetime.datetime.utcnow) diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index f1ef015..50beaf3 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -1,8 +1,14 @@ +import ast import asyncio +import csv +import re from copy import deepcopy +from datetime import datetime +from pathlib import Path from typing import Optional import gradio as gr +from filelock import FileLock from ktem.app import BasePage from ktem.components import reasonings from ktem.db.models import Conversation, engine @@ -10,6 +16,9 @@ from ktem.index.file.ui import File from ktem.reasoning.prompt_optimization.suggest_conversation_name import ( SuggestConvNamePipeline, ) +from ktem.reasoning.prompt_optimization.suggest_followup_chat import ( + SuggestFollowupQuesPipeline, +) from plotly.io import from_json from sqlmodel import Session, select from theflow.settings import settings as flowsettings @@ -17,8 +26,8 @@ from theflow.settings import settings as flowsettings from kotaemon.base import Document from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS +from ...utils import SUPPORTED_LANGUAGE_MAP from .chat_panel import ChatPanel -from .chat_suggestion import ChatSuggestion from .common import STATE from .control import ConversationControl from .report import ReportIssue @@ -50,6 +59,7 @@ class ChatPage(BasePage): self._reasoning_type = gr.State(value=None) self._llm_type = gr.State(value=None) self._conversation_renamed = gr.State(value=False) + self._suggestion_updated = gr.State(value=False) self._info_panel_expanded = gr.State(value=True) def on_building_ui(self): @@ -58,13 +68,11 @@ class ChatPage(BasePage): self.state_retrieval_history = gr.State([]) self.state_plot_history = gr.State([]) self.state_plot_panel = gr.State(None) + self.state_follow_up = gr.State(None) with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column: self.chat_control = ConversationControl(self._app) - if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): - self.chat_suggestion = ChatSuggestion(self._app) - for index_id, index in enumerate(self._app.index_manager.indices): index.selector = None index_ui = index.get_selector_component_ui() @@ -156,6 +164,11 @@ class ChatPage(BasePage): return plot def on_register_events(self): + if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): + self.state_follow_up = self.chat_control.chat_suggestion.example + else: + self.state_follow_up = self.chat_control.followup_suggestions + gr.on( triggers=[ self.chat_panel.text_input.submit, @@ -168,6 +181,7 @@ class ChatPage(BasePage): self._app.user_id, self.chat_control.conversation_id, self.chat_control.conversation_rn, + self.state_follow_up, ], outputs=[ self.chat_panel.text_input, @@ -175,6 +189,7 @@ class ChatPage(BasePage): self.chat_control.conversation_id, self.chat_control.conversation, self.chat_control.conversation_rn, + self.state_follow_up, ], concurrency_limit=20, show_progress="hidden", @@ -225,6 +240,30 @@ class ChatPage(BasePage): self.chat_control.conversation_rn, ], show_progress="hidden", + ).then( + fn=self.suggest_chat_conv, + inputs=[ + self._app.settings_state, + self.chat_panel.chatbot, + ], + outputs=[ + self.state_follow_up, + self._suggestion_updated, + ], + show_progress="hidden", + ).success( + self.chat_control.update_chat_suggestions, + inputs=[ + self.chat_control.conversation_id, + self.state_follow_up, + self._suggestion_updated, + self._app.user_id, + ], + outputs=[ + self.chat_control.conversation, + self.chat_control.conversation, + ], + show_progress="hidden", ).then( fn=self.persist_data_source, inputs=[ @@ -292,6 +331,30 @@ class ChatPage(BasePage): self.chat_control.conversation_rn, ], show_progress="hidden", + ).then( + fn=self.suggest_chat_conv, + inputs=[ + self._app.settings_state, + self.chat_panel.chatbot, + ], + outputs=[ + self.state_follow_up, + self._suggestion_updated, + ], + show_progress="hidden", + ).success( + self.chat_control.update_chat_suggestions, + inputs=[ + self.chat_control.conversation_id, + self.state_follow_up, + self._suggestion_updated, + self._app.user_id, + ], + outputs=[ + self.chat_control.conversation, + self.chat_control.conversation, + ], + show_progress="hidden", ).then( fn=self.persist_data_source, inputs=[ @@ -339,6 +402,7 @@ class ChatPage(BasePage): self.chat_control.conversation, self.chat_control.conversation_rn, self.chat_panel.chatbot, + self.state_follow_up, self.info_panel, self.state_plot_panel, self.state_retrieval_history, @@ -372,6 +436,7 @@ class ChatPage(BasePage): self.chat_control.conversation, self.chat_control.conversation_rn, self.chat_panel.chatbot, + self.state_follow_up, self.info_panel, self.state_plot_panel, self.state_retrieval_history, @@ -423,6 +488,7 @@ class ChatPage(BasePage): self.chat_control.conversation, self.chat_control.conversation_rn, self.chat_panel.chatbot, + self.state_follow_up, self.info_panel, self.state_plot_panel, self.state_retrieval_history, @@ -501,13 +567,15 @@ class ChatPage(BasePage): ) if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): - self.chat_suggestion.example.select( - self.chat_suggestion.select_example, + self.state_follow_up.select( + self.chat_control.chat_suggestion.select_example, outputs=[self.chat_panel.text_input], show_progress="hidden", ) - def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name): + def submit_msg( + self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest + ): """Submit a message to the chatbot""" if not chat_input: raise ValueError("Input is empty") @@ -517,13 +585,20 @@ class ChatPage(BasePage): with Session(engine) as session: statement = select(Conversation).where(Conversation.id == id_) name = session.exec(statement).one().name + suggestion = ( + session.exec(statement) + .one() + .data_source.get("chat_suggestions", []) + ) new_conv_id = id_ conv_update = update new_conv_name = name + new_chat_suggestion = suggestion else: new_conv_id = conv_id conv_update = gr.update() new_conv_name = conv_name + new_chat_suggestion = chat_suggest return ( "", @@ -531,6 +606,7 @@ class ChatPage(BasePage): new_conv_id, conv_update, new_conv_name, + new_chat_suggestion, ) def toggle_delete(self, conv_id): @@ -872,3 +948,118 @@ class ChatPage(BasePage): renamed = True return new_name, renamed + + def suggest_chat_conv(self, settings, chat_history): + suggest_pipeline = SuggestFollowupQuesPipeline() + suggest_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get( + settings["reasoning.lang"], "English" + ) + + updated = False + + suggested_ques = [] + if len(chat_history) >= 1: + suggested_resp = suggest_pipeline(chat_history).text + if ques_res := re.search(r"\[(.*?)\]", re.sub("\n", "", suggested_resp)): + ques_res_str = ques_res.group() + try: + suggested_ques = ast.literal_eval(ques_res_str) + suggested_ques = [[x] for x in suggested_ques] + updated = True + except Exception: + pass + + return suggested_ques, updated + + def backup_original_info( + self, chat_history, settings, info_pannel, original_chat_history + ): + original_chat_history.append(chat_history[-1]) + return original_chat_history, settings, info_pannel + + def save_log( + self, + conversation_id, + chat_history, + settings, + info_panel, + original_chat_history, + original_settings, + original_info_panel, + log_dir, + ): + if not Path(log_dir).exists(): + Path(log_dir).mkdir(parents=True) + + lock = FileLock(Path(log_dir) / ".lock") + # get current date + today = datetime.now() + formatted_date = today.strftime("%d%m%Y_%H") + + with Session(engine) as session: + statement = select(Conversation).where(Conversation.id == conversation_id) + result = session.exec(statement).one() + + data_source = deepcopy(result.data_source) + likes = data_source.get("likes", []) + if not likes: + return + + feedback = likes[-1][-1] + message_index = likes[-1][0] + + current_message = chat_history[message_index[0]] + original_message = original_chat_history[message_index[0]] + is_original = all( + [ + current_item == original_item + for current_item, original_item in zip( + current_message, original_message + ) + ] + ) + + dataframe = [ + [ + conversation_id, + message_index, + current_message[0], + current_message[1], + chat_history, + settings, + info_panel, + feedback, + is_original, + original_message[1], + original_chat_history, + original_settings, + original_info_panel, + ] + ] + + with lock: + log_file = Path(log_dir) / f"{formatted_date}_log.csv" + is_log_file_exist = log_file.is_file() + with open(log_file, "a") as f: + writer = csv.writer(f) + # write headers + if not is_log_file_exist: + writer.writerow( + [ + "Conversation ID", + "Message ID", + "Question", + "Answer", + "Chat History", + "Settings", + "Evidences", + "Feedback", + "Original/ Rewritten", + "Original Answer", + "Original Chat History", + "Original Settings", + "Original Evidences", + ] + ) + + writer.writerows(dataframe) diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py index 6fd47eb..5a9f0bf 100644 --- a/libs/ktem/ktem/pages/chat/control.py +++ b/libs/ktem/ktem/pages/chat/control.py @@ -1,5 +1,6 @@ import logging import os +from copy import deepcopy import gradio as gr from ktem.app import BasePage @@ -9,6 +10,7 @@ from sqlmodel import Session, or_, select import flowsettings from ...utils.conversation import sync_retrieval_n_message +from .chat_suggestion import ChatSuggestion from .common import STATE logger = logging.getLogger(__name__) @@ -103,6 +105,10 @@ class ConversationControl(BasePage): visible=False, ) + self.followup_suggestions = gr.State([]) + if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): + self.chat_suggestion = ChatSuggestion(self._app) + def load_chat_history(self, user_id): """Reload chat history""" @@ -220,6 +226,8 @@ class ConversationControl(BasePage): chats = result.data_source.get("messages", []) + chat_suggestions = result.data_source.get("chat_suggestions", []) + retrieval_history: list[str] = result.data_source.get( "retrieval_messages", [] ) @@ -243,6 +251,7 @@ class ConversationControl(BasePage): name = "" selected = {} chats = [] + chat_suggestions = [] retrieval_history = [] plot_history = [] info_panel = "" @@ -265,6 +274,7 @@ class ConversationControl(BasePage): id_, name, chats, + chat_suggestions, info_panel, plot_data, retrieval_history, @@ -311,6 +321,46 @@ class ConversationControl(BasePage): gr.update(visible=False), ) + def update_chat_suggestions( + self, conversation_id, new_suggestions, is_updated, user_id + ): + """Update the conversation's chat suggestions""" + if not is_updated: + return ( + gr.update(), + conversation_id, + gr.update(visible=False), + ) + + if user_id is None: + gr.Warning("Please sign in first (Settings → User Settings)") + return gr.update(), "" + + if not conversation_id: + gr.Warning("No conversation selected.") + return gr.update(), "" + + with Session(engine) as session: + statement = select(Conversation).where(Conversation.id == conversation_id) + result = session.exec(statement).one() + + data_source = deepcopy(result.data_source) + data_source["chat_suggestions"] = [ + [x] for x in new_suggestions.iloc[:, 0].tolist() + ] + + result.data_source = data_source + session.add(result) + session.commit() + + history = self.load_chat_history(user_id) + gr.Info("Chat suggestions updated.") + return ( + gr.update(choices=history), + conversation_id, + gr.update(visible=False), + ) + def _on_app_created(self): """Reload the conversation once the app is created""" self._app.app.load( diff --git a/libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py b/libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py new file mode 100644 index 0000000..53d46b3 --- /dev/null +++ b/libs/ktem/ktem/reasoning/prompt_optimization/suggest_followup_chat.py @@ -0,0 +1,45 @@ +import logging + +from ktem.llms.manager import llms + +from kotaemon.base import AIMessage, BaseComponent, Document, HumanMessage, Node +from kotaemon.llms import ChatLLM, PromptTemplate + +logger = logging.getLogger(__name__) + + +class SuggestFollowupQuesPipeline(BaseComponent): + """Suggest a list of follow-up questions based on the chat history.""" + + llm: ChatLLM = Node(default_callback=lambda _: llms.get_default()) + SUGGEST_QUESTIONS_PROMPT_TEMPLATE = ( + "Based on the chat history above. " + "your task is to generate 3 to 5 relevant follow-up questions. " + "These questions should be simple, clear, " + "and designed to guide the conversation further. " + "Ensure that the questions are open-ended to encourage detailed responses. " + "Respond in JSON format with 'questions' key. " + "Answer using the language {lang} same as the question. " + "If the question uses Chinese, the answer should be in Chinese.\n" + ) + prompt_template: str = SUGGEST_QUESTIONS_PROMPT_TEMPLATE + extra_prompt: str = """Example of valid response: +```json +{ + "questions": ["the weather is good", "what's your favorite city"] +} +```""" + lang: str = "English" + + def run(self, chat_history: list[tuple[str, str]]) -> Document: + prompt_template = PromptTemplate(self.prompt_template) + prompt = prompt_template.populate(lang=self.lang) + self.extra_prompt + + messages = [] + for human, ai in chat_history[-3:]: + messages.append(HumanMessage(content=human)) + messages.append(AIMessage(content=ai)) + + messages.append(HumanMessage(content=prompt)) + + return self.llm(messages)