feat: tweak the 'Chat suggestion' feature to tie it to conversations (#341)

Signed-off-by: Kennywu <jdlow@live.cn>
This commit is contained in:
KennyWu 2024-10-10 12:02:04 +08:00 committed by GitHub
parent 96f58a445a
commit 49a083fd9f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 295 additions and 8 deletions

View File

@ -63,6 +63,7 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
KH_DOC_DIR = this_dir / "docs"
KH_MODE = "dev"
KH_FEATURE_CHAT_SUGGESTION = config("KH_FEATURE_CHAT_SUGGESTION", default=False)
KH_FEATURE_USER_MANAGEMENT = config(
"KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
)

View File

@ -34,7 +34,7 @@ class BaseConversation(SQLModel):
is_public: bool = Field(default=False)
# contains messages + current files
# contains messages + current files + chat_suggestions
data_source: dict = Field(default={}, sa_column=Column(JSON))
date_created: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)

View File

@ -1,8 +1,14 @@
import ast
import asyncio
import csv
import re
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Optional
import gradio as gr
from filelock import FileLock
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
@ -10,6 +16,9 @@ from ktem.index.file.ui import File
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline,
)
from ktem.reasoning.prompt_optimization.suggest_followup_chat import (
SuggestFollowupQuesPipeline,
)
from plotly.io import from_json
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
@ -17,8 +26,8 @@ from theflow.settings import settings as flowsettings
from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from ...utils import SUPPORTED_LANGUAGE_MAP
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE
from .control import ConversationControl
from .report import ReportIssue
@ -50,6 +59,7 @@ class ChatPage(BasePage):
self._reasoning_type = gr.State(value=None)
self._llm_type = gr.State(value=None)
self._conversation_renamed = gr.State(value=False)
self._suggestion_updated = gr.State(value=False)
self._info_panel_expanded = gr.State(value=True)
def on_building_ui(self):
@ -58,13 +68,11 @@ class ChatPage(BasePage):
self.state_retrieval_history = gr.State([])
self.state_plot_history = gr.State([])
self.state_plot_panel = gr.State(None)
self.state_follow_up = gr.State(None)
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
self.chat_control = ConversationControl(self._app)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion = ChatSuggestion(self._app)
for index_id, index in enumerate(self._app.index_manager.indices):
index.selector = None
index_ui = index.get_selector_component_ui()
@ -156,6 +164,11 @@ class ChatPage(BasePage):
return plot
def on_register_events(self):
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.state_follow_up = self.chat_control.chat_suggestion.example
else:
self.state_follow_up = self.chat_control.followup_suggestions
gr.on(
triggers=[
self.chat_panel.text_input.submit,
@ -168,6 +181,7 @@ class ChatPage(BasePage):
self._app.user_id,
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self.state_follow_up,
],
outputs=[
self.chat_panel.text_input,
@ -175,6 +189,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.state_follow_up,
],
concurrency_limit=20,
show_progress="hidden",
@ -225,6 +240,30 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
],
show_progress="hidden",
).then(
fn=self.suggest_chat_conv,
inputs=[
self._app.settings_state,
self.chat_panel.chatbot,
],
outputs=[
self.state_follow_up,
self._suggestion_updated,
],
show_progress="hidden",
).success(
self.chat_control.update_chat_suggestions,
inputs=[
self.chat_control.conversation_id,
self.state_follow_up,
self._suggestion_updated,
self._app.user_id,
],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
],
show_progress="hidden",
).then(
fn=self.persist_data_source,
inputs=[
@ -292,6 +331,30 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
],
show_progress="hidden",
).then(
fn=self.suggest_chat_conv,
inputs=[
self._app.settings_state,
self.chat_panel.chatbot,
],
outputs=[
self.state_follow_up,
self._suggestion_updated,
],
show_progress="hidden",
).success(
self.chat_control.update_chat_suggestions,
inputs=[
self.chat_control.conversation_id,
self.state_follow_up,
self._suggestion_updated,
self._app.user_id,
],
outputs=[
self.chat_control.conversation,
self.chat_control.conversation,
],
show_progress="hidden",
).then(
fn=self.persist_data_source,
inputs=[
@ -339,6 +402,7 @@ class ChatPage(BasePage):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.state_follow_up,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
@ -372,6 +436,7 @@ class ChatPage(BasePage):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.state_follow_up,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
@ -423,6 +488,7 @@ class ChatPage(BasePage):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.state_follow_up,
self.info_panel,
self.state_plot_panel,
self.state_retrieval_history,
@ -501,13 +567,15 @@ class ChatPage(BasePage):
)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion.example.select(
self.chat_suggestion.select_example,
self.state_follow_up.select(
self.chat_control.chat_suggestion.select_example,
outputs=[self.chat_panel.text_input],
show_progress="hidden",
)
def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
def submit_msg(
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest
):
"""Submit a message to the chatbot"""
if not chat_input:
raise ValueError("Input is empty")
@ -517,13 +585,20 @@ class ChatPage(BasePage):
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == id_)
name = session.exec(statement).one().name
suggestion = (
session.exec(statement)
.one()
.data_source.get("chat_suggestions", [])
)
new_conv_id = id_
conv_update = update
new_conv_name = name
new_chat_suggestion = suggestion
else:
new_conv_id = conv_id
conv_update = gr.update()
new_conv_name = conv_name
new_chat_suggestion = chat_suggest
return (
"",
@ -531,6 +606,7 @@ class ChatPage(BasePage):
new_conv_id,
conv_update,
new_conv_name,
new_chat_suggestion,
)
def toggle_delete(self, conv_id):
@ -872,3 +948,118 @@ class ChatPage(BasePage):
renamed = True
return new_name, renamed
def suggest_chat_conv(self, settings, chat_history):
    """Generate follow-up question suggestions from the recent chat history.

    Returns a tuple ``(suggestions, updated)`` where ``suggestions`` is a
    list of one-element lists (one question each, the shape the examples
    component expects) and ``updated`` tells whether a fresh set of
    suggestions was successfully parsed from the LLM response.
    """
    pipeline = SuggestFollowupQuesPipeline()
    # Answer in the user's configured UI language; default to English.
    pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
        settings["reasoning.lang"], "English"
    )

    suggestions = []
    parsed_ok = False

    if chat_history:
        response_text = pipeline(chat_history).text
        # Grab the first bracketed list in the newline-stripped response.
        match = re.search(r"\[(.*?)\]", response_text.replace("\n", ""))
        if match:
            try:
                questions = ast.literal_eval(match.group())
                suggestions = [[question] for question in questions]
                parsed_ok = True
            except Exception:
                # Malformed list literal — keep the empty result (best effort).
                pass

    return suggestions, parsed_ok
def backup_original_info(
    self, chat_history, settings, info_pannel, original_chat_history
):
    """Snapshot the newest chat turn alongside the current settings/panel.

    Appends the last entry of ``chat_history`` to ``original_chat_history``
    (mutated in place) and passes ``settings`` and ``info_pannel`` through
    unchanged so all three can be stored together.
    """
    latest_turn = chat_history[-1]
    original_chat_history.append(latest_turn)
    return original_chat_history, settings, info_pannel
def save_log(
    self,
    conversation_id,
    chat_history,
    settings,
    info_panel,
    original_chat_history,
    original_settings,
    original_info_panel,
    log_dir,
):
    """Append the latest like/dislike feedback for a conversation to a CSV log.

    Reads the most recent entry of the conversation's ``likes`` list from the
    database, pairs the current (possibly rewritten) message with its original
    version, and appends one row to an hourly log file under ``log_dir``.
    File writes are serialized across processes with a file lock. Returns
    early (logging nothing) when the conversation has no feedback yet.
    """
    if not Path(log_dir).exists():
        Path(log_dir).mkdir(parents=True)

    lock = FileLock(Path(log_dir) / ".lock")
    # One log file per hour, e.g. "10102024_12_log.csv".
    today = datetime.now()
    formatted_date = today.strftime("%d%m%Y_%H")

    with Session(engine) as session:
        statement = select(Conversation).where(Conversation.id == conversation_id)
        result = session.exec(statement).one()
        data_source = deepcopy(result.data_source)
        likes = data_source.get("likes", [])
        if not likes:
            # No feedback recorded yet — nothing to log.
            return

        feedback = likes[-1][-1]
        message_index = likes[-1][0]

        current_message = chat_history[message_index[0]]
        original_message = original_chat_history[message_index[0]]
        # The answer counts as "original" only if no element of the turn
        # was edited after generation.
        is_original = all(
            current_item == original_item
            for current_item, original_item in zip(current_message, original_message)
        )

        dataframe = [
            [
                conversation_id,
                message_index,
                current_message[0],
                current_message[1],
                chat_history,
                settings,
                info_panel,
                feedback,
                is_original,
                original_message[1],
                original_chat_history,
                original_settings,
                original_info_panel,
            ]
        ]

    with lock:
        log_file = Path(log_dir) / f"{formatted_date}_log.csv"
        is_log_file_exist = log_file.is_file()
        # newline="" is required by the csv module (avoids blank rows on
        # Windows); utf-8 keeps non-ASCII chat content intact.
        with open(log_file, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            # Write headers only when creating the file.
            if not is_log_file_exist:
                writer.writerow(
                    [
                        "Conversation ID",
                        "Message ID",
                        "Question",
                        "Answer",
                        "Chat History",
                        "Settings",
                        "Evidences",
                        "Feedback",
                        "Original/ Rewritten",
                        "Original Answer",
                        "Original Chat History",
                        "Original Settings",
                        "Original Evidences",
                    ]
                )
            writer.writerows(dataframe)

View File

@ -1,5 +1,6 @@
import logging
import os
from copy import deepcopy
import gradio as gr
from ktem.app import BasePage
@ -9,6 +10,7 @@ from sqlmodel import Session, or_, select
import flowsettings
from ...utils.conversation import sync_retrieval_n_message
from .chat_suggestion import ChatSuggestion
from .common import STATE
logger = logging.getLogger(__name__)
@ -103,6 +105,10 @@ class ConversationControl(BasePage):
visible=False,
)
self.followup_suggestions = gr.State([])
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion = ChatSuggestion(self._app)
def load_chat_history(self, user_id):
"""Reload chat history"""
@ -220,6 +226,8 @@ class ConversationControl(BasePage):
chats = result.data_source.get("messages", [])
chat_suggestions = result.data_source.get("chat_suggestions", [])
retrieval_history: list[str] = result.data_source.get(
"retrieval_messages", []
)
@ -243,6 +251,7 @@ class ConversationControl(BasePage):
name = ""
selected = {}
chats = []
chat_suggestions = []
retrieval_history = []
plot_history = []
info_panel = ""
@ -265,6 +274,7 @@ class ConversationControl(BasePage):
id_,
name,
chats,
chat_suggestions,
info_panel,
plot_data,
retrieval_history,
@ -311,6 +321,46 @@ class ConversationControl(BasePage):
gr.update(visible=False),
)
def update_chat_suggestions(
    self, conversation_id, new_suggestions, is_updated, user_id
):
    """Persist freshly generated chat suggestions on the conversation.

    Stores ``new_suggestions`` under ``data_source["chat_suggestions"]`` of
    the selected conversation and refreshes the conversation list.

    Fix: every exit now returns the same 3-tuple shape
    (conversation-list update, conversation id, rename-box visibility) —
    the warning paths previously returned only 2 values, which breaks the
    gradio outputs mapping.
    """
    if not is_updated:
        # Nothing new to store; leave the UI untouched.
        return (
            gr.update(),
            conversation_id,
            gr.update(visible=False),
        )

    if user_id is None:
        gr.Warning("Please sign in first (Settings → User Settings)")
        return gr.update(), conversation_id, gr.update(visible=False)

    if not conversation_id:
        gr.Warning("No conversation selected.")
        return gr.update(), conversation_id, gr.update(visible=False)

    with Session(engine) as session:
        statement = select(Conversation).where(Conversation.id == conversation_id)
        result = session.exec(statement).one()

        data_source = deepcopy(result.data_source)
        # NOTE(review): new_suggestions appears to be a pandas DataFrame
        # (uses ``.iloc``) whose first column holds the questions — confirm
        # against the wired gradio component. Each question is wrapped in a
        # one-element list, the shape the examples component expects.
        data_source["chat_suggestions"] = [
            [x] for x in new_suggestions.iloc[:, 0].tolist()
        ]

        result.data_source = data_source
        session.add(result)
        session.commit()

    history = self.load_chat_history(user_id)
    gr.Info("Chat suggestions updated.")
    return (
        gr.update(choices=history),
        conversation_id,
        gr.update(visible=False),
    )
def _on_app_created(self):
"""Reload the conversation once the app is created"""
self._app.app.load(

View File

@ -0,0 +1,45 @@
import logging
from ktem.llms.manager import llms
from kotaemon.base import AIMessage, BaseComponent, Document, HumanMessage, Node
from kotaemon.llms import ChatLLM, PromptTemplate
logger = logging.getLogger(__name__)
class SuggestFollowupQuesPipeline(BaseComponent):
    """Suggest a list of follow-up questions based on the chat history.

    The pipeline replays the last few (human, AI) turns to the default LLM
    and asks it to propose 3–5 open-ended follow-up questions in the
    configured language.
    """

    # Resolved lazily so the pipeline picks up the currently-configured LLM.
    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())

    SUGGEST_QUESTIONS_PROMPT_TEMPLATE = (
        "Based on the chat history above. "
        "your task is to generate 3 to 5 relevant follow-up questions. "
        "These questions should be simple, clear, "
        "and designed to guide the conversation further. "
        "Ensure that the questions are open-ended to encourage detailed responses. "
        "Respond in JSON format with 'questions' key. "
        "Answer using the language {lang} same as the question. "
        "If the question uses Chinese, the answer should be in Chinese.\n"
    )
    prompt_template: str = SUGGEST_QUESTIONS_PROMPT_TEMPLATE
    extra_prompt: str = """Example of valid response:
```json
{
    "questions": ["the weather is good", "what's your favorite city"]
}
```"""
    lang: str = "English"

    def run(self, chat_history: list[tuple[str, str]]) -> Document:
        """Build the prompt from recent turns and query the LLM."""
        template = PromptTemplate(self.prompt_template)
        instruction = template.populate(lang=self.lang) + self.extra_prompt

        # Replay at most the three most recent (human, ai) exchanges, then
        # append the instruction as the final human message.
        messages = []
        for user_text, ai_text in chat_history[-3:]:
            messages.extend(
                [HumanMessage(content=user_text), AIMessage(content=ai_text)]
            )
        messages.append(HumanMessage(content=instruction))

        return self.llm(messages)