feat: tweak the 'Chat suggestion' feature to tie it to conversations (#341) #none
Signed-off-by: Kennywu <jdlow@live.cn>
This commit is contained in:
parent
96f58a445a
commit
49a083fd9f
|
@ -63,6 +63,7 @@ os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
|
||||||
KH_DOC_DIR = this_dir / "docs"
|
KH_DOC_DIR = this_dir / "docs"
|
||||||
|
|
||||||
KH_MODE = "dev"
|
KH_MODE = "dev"
|
||||||
|
KH_FEATURE_CHAT_SUGGESTION = config("KH_FEATURE_CHAT_SUGGESTION", default=False)
|
||||||
KH_FEATURE_USER_MANAGEMENT = config(
|
KH_FEATURE_USER_MANAGEMENT = config(
|
||||||
"KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
|
"KH_FEATURE_USER_MANAGEMENT", default=True, cast=bool
|
||||||
)
|
)
|
||||||
|
|
|
@ -34,7 +34,7 @@ class BaseConversation(SQLModel):
|
||||||
|
|
||||||
is_public: bool = Field(default=False)
|
is_public: bool = Field(default=False)
|
||||||
|
|
||||||
# contains messages + current files
|
# contains messages + current files + chat_suggestions
|
||||||
data_source: dict = Field(default={}, sa_column=Column(JSON))
|
data_source: dict = Field(default={}, sa_column=Column(JSON))
|
||||||
|
|
||||||
date_created: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
date_created: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||||
|
|
|
@ -1,8 +1,14 @@
|
||||||
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import csv
|
||||||
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from filelock import FileLock
|
||||||
from ktem.app import BasePage
|
from ktem.app import BasePage
|
||||||
from ktem.components import reasonings
|
from ktem.components import reasonings
|
||||||
from ktem.db.models import Conversation, engine
|
from ktem.db.models import Conversation, engine
|
||||||
|
@ -10,6 +16,9 @@ from ktem.index.file.ui import File
|
||||||
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
|
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
|
||||||
SuggestConvNamePipeline,
|
SuggestConvNamePipeline,
|
||||||
)
|
)
|
||||||
|
from ktem.reasoning.prompt_optimization.suggest_followup_chat import (
|
||||||
|
SuggestFollowupQuesPipeline,
|
||||||
|
)
|
||||||
from plotly.io import from_json
|
from plotly.io import from_json
|
||||||
from sqlmodel import Session, select
|
from sqlmodel import Session, select
|
||||||
from theflow.settings import settings as flowsettings
|
from theflow.settings import settings as flowsettings
|
||||||
|
@ -17,8 +26,8 @@ from theflow.settings import settings as flowsettings
|
||||||
from kotaemon.base import Document
|
from kotaemon.base import Document
|
||||||
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
|
||||||
|
|
||||||
|
from ...utils import SUPPORTED_LANGUAGE_MAP
|
||||||
from .chat_panel import ChatPanel
|
from .chat_panel import ChatPanel
|
||||||
from .chat_suggestion import ChatSuggestion
|
|
||||||
from .common import STATE
|
from .common import STATE
|
||||||
from .control import ConversationControl
|
from .control import ConversationControl
|
||||||
from .report import ReportIssue
|
from .report import ReportIssue
|
||||||
|
@ -50,6 +59,7 @@ class ChatPage(BasePage):
|
||||||
self._reasoning_type = gr.State(value=None)
|
self._reasoning_type = gr.State(value=None)
|
||||||
self._llm_type = gr.State(value=None)
|
self._llm_type = gr.State(value=None)
|
||||||
self._conversation_renamed = gr.State(value=False)
|
self._conversation_renamed = gr.State(value=False)
|
||||||
|
self._suggestion_updated = gr.State(value=False)
|
||||||
self._info_panel_expanded = gr.State(value=True)
|
self._info_panel_expanded = gr.State(value=True)
|
||||||
|
|
||||||
def on_building_ui(self):
|
def on_building_ui(self):
|
||||||
|
@ -58,13 +68,11 @@ class ChatPage(BasePage):
|
||||||
self.state_retrieval_history = gr.State([])
|
self.state_retrieval_history = gr.State([])
|
||||||
self.state_plot_history = gr.State([])
|
self.state_plot_history = gr.State([])
|
||||||
self.state_plot_panel = gr.State(None)
|
self.state_plot_panel = gr.State(None)
|
||||||
|
self.state_follow_up = gr.State(None)
|
||||||
|
|
||||||
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
|
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
|
||||||
self.chat_control = ConversationControl(self._app)
|
self.chat_control = ConversationControl(self._app)
|
||||||
|
|
||||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
|
||||||
self.chat_suggestion = ChatSuggestion(self._app)
|
|
||||||
|
|
||||||
for index_id, index in enumerate(self._app.index_manager.indices):
|
for index_id, index in enumerate(self._app.index_manager.indices):
|
||||||
index.selector = None
|
index.selector = None
|
||||||
index_ui = index.get_selector_component_ui()
|
index_ui = index.get_selector_component_ui()
|
||||||
|
@ -156,6 +164,11 @@ class ChatPage(BasePage):
|
||||||
return plot
|
return plot
|
||||||
|
|
||||||
def on_register_events(self):
|
def on_register_events(self):
|
||||||
|
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||||
|
self.state_follow_up = self.chat_control.chat_suggestion.example
|
||||||
|
else:
|
||||||
|
self.state_follow_up = self.chat_control.followup_suggestions
|
||||||
|
|
||||||
gr.on(
|
gr.on(
|
||||||
triggers=[
|
triggers=[
|
||||||
self.chat_panel.text_input.submit,
|
self.chat_panel.text_input.submit,
|
||||||
|
@ -168,6 +181,7 @@ class ChatPage(BasePage):
|
||||||
self._app.user_id,
|
self._app.user_id,
|
||||||
self.chat_control.conversation_id,
|
self.chat_control.conversation_id,
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
|
self.state_follow_up,
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
self.chat_panel.text_input,
|
self.chat_panel.text_input,
|
||||||
|
@ -175,6 +189,7 @@ class ChatPage(BasePage):
|
||||||
self.chat_control.conversation_id,
|
self.chat_control.conversation_id,
|
||||||
self.chat_control.conversation,
|
self.chat_control.conversation,
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
|
self.state_follow_up,
|
||||||
],
|
],
|
||||||
concurrency_limit=20,
|
concurrency_limit=20,
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
|
@ -225,6 +240,30 @@ class ChatPage(BasePage):
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
],
|
],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
|
).then(
|
||||||
|
fn=self.suggest_chat_conv,
|
||||||
|
inputs=[
|
||||||
|
self._app.settings_state,
|
||||||
|
self.chat_panel.chatbot,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
self.state_follow_up,
|
||||||
|
self._suggestion_updated,
|
||||||
|
],
|
||||||
|
show_progress="hidden",
|
||||||
|
).success(
|
||||||
|
self.chat_control.update_chat_suggestions,
|
||||||
|
inputs=[
|
||||||
|
self.chat_control.conversation_id,
|
||||||
|
self.state_follow_up,
|
||||||
|
self._suggestion_updated,
|
||||||
|
self._app.user_id,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
self.chat_control.conversation,
|
||||||
|
self.chat_control.conversation,
|
||||||
|
],
|
||||||
|
show_progress="hidden",
|
||||||
).then(
|
).then(
|
||||||
fn=self.persist_data_source,
|
fn=self.persist_data_source,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
@ -292,6 +331,30 @@ class ChatPage(BasePage):
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
],
|
],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
|
).then(
|
||||||
|
fn=self.suggest_chat_conv,
|
||||||
|
inputs=[
|
||||||
|
self._app.settings_state,
|
||||||
|
self.chat_panel.chatbot,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
self.state_follow_up,
|
||||||
|
self._suggestion_updated,
|
||||||
|
],
|
||||||
|
show_progress="hidden",
|
||||||
|
).success(
|
||||||
|
self.chat_control.update_chat_suggestions,
|
||||||
|
inputs=[
|
||||||
|
self.chat_control.conversation_id,
|
||||||
|
self.state_follow_up,
|
||||||
|
self._suggestion_updated,
|
||||||
|
self._app.user_id,
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
self.chat_control.conversation,
|
||||||
|
self.chat_control.conversation,
|
||||||
|
],
|
||||||
|
show_progress="hidden",
|
||||||
).then(
|
).then(
|
||||||
fn=self.persist_data_source,
|
fn=self.persist_data_source,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
@ -339,6 +402,7 @@ class ChatPage(BasePage):
|
||||||
self.chat_control.conversation,
|
self.chat_control.conversation,
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
self.chat_panel.chatbot,
|
self.chat_panel.chatbot,
|
||||||
|
self.state_follow_up,
|
||||||
self.info_panel,
|
self.info_panel,
|
||||||
self.state_plot_panel,
|
self.state_plot_panel,
|
||||||
self.state_retrieval_history,
|
self.state_retrieval_history,
|
||||||
|
@ -372,6 +436,7 @@ class ChatPage(BasePage):
|
||||||
self.chat_control.conversation,
|
self.chat_control.conversation,
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
self.chat_panel.chatbot,
|
self.chat_panel.chatbot,
|
||||||
|
self.state_follow_up,
|
||||||
self.info_panel,
|
self.info_panel,
|
||||||
self.state_plot_panel,
|
self.state_plot_panel,
|
||||||
self.state_retrieval_history,
|
self.state_retrieval_history,
|
||||||
|
@ -423,6 +488,7 @@ class ChatPage(BasePage):
|
||||||
self.chat_control.conversation,
|
self.chat_control.conversation,
|
||||||
self.chat_control.conversation_rn,
|
self.chat_control.conversation_rn,
|
||||||
self.chat_panel.chatbot,
|
self.chat_panel.chatbot,
|
||||||
|
self.state_follow_up,
|
||||||
self.info_panel,
|
self.info_panel,
|
||||||
self.state_plot_panel,
|
self.state_plot_panel,
|
||||||
self.state_retrieval_history,
|
self.state_retrieval_history,
|
||||||
|
@ -501,13 +567,15 @@ class ChatPage(BasePage):
|
||||||
)
|
)
|
||||||
|
|
||||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||||
self.chat_suggestion.example.select(
|
self.state_follow_up.select(
|
||||||
self.chat_suggestion.select_example,
|
self.chat_control.chat_suggestion.select_example,
|
||||||
outputs=[self.chat_panel.text_input],
|
outputs=[self.chat_panel.text_input],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
)
|
)
|
||||||
|
|
||||||
def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
|
def submit_msg(
|
||||||
|
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest
|
||||||
|
):
|
||||||
"""Submit a message to the chatbot"""
|
"""Submit a message to the chatbot"""
|
||||||
if not chat_input:
|
if not chat_input:
|
||||||
raise ValueError("Input is empty")
|
raise ValueError("Input is empty")
|
||||||
|
@ -517,13 +585,20 @@ class ChatPage(BasePage):
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
statement = select(Conversation).where(Conversation.id == id_)
|
statement = select(Conversation).where(Conversation.id == id_)
|
||||||
name = session.exec(statement).one().name
|
name = session.exec(statement).one().name
|
||||||
|
suggestion = (
|
||||||
|
session.exec(statement)
|
||||||
|
.one()
|
||||||
|
.data_source.get("chat_suggestions", [])
|
||||||
|
)
|
||||||
new_conv_id = id_
|
new_conv_id = id_
|
||||||
conv_update = update
|
conv_update = update
|
||||||
new_conv_name = name
|
new_conv_name = name
|
||||||
|
new_chat_suggestion = suggestion
|
||||||
else:
|
else:
|
||||||
new_conv_id = conv_id
|
new_conv_id = conv_id
|
||||||
conv_update = gr.update()
|
conv_update = gr.update()
|
||||||
new_conv_name = conv_name
|
new_conv_name = conv_name
|
||||||
|
new_chat_suggestion = chat_suggest
|
||||||
|
|
||||||
return (
|
return (
|
||||||
"",
|
"",
|
||||||
|
@ -531,6 +606,7 @@ class ChatPage(BasePage):
|
||||||
new_conv_id,
|
new_conv_id,
|
||||||
conv_update,
|
conv_update,
|
||||||
new_conv_name,
|
new_conv_name,
|
||||||
|
new_chat_suggestion,
|
||||||
)
|
)
|
||||||
|
|
||||||
def toggle_delete(self, conv_id):
|
def toggle_delete(self, conv_id):
|
||||||
|
@ -872,3 +948,118 @@ class ChatPage(BasePage):
|
||||||
renamed = True
|
renamed = True
|
||||||
|
|
||||||
return new_name, renamed
|
return new_name, renamed
|
||||||
|
|
||||||
|
def suggest_chat_conv(self, settings, chat_history):
|
||||||
|
suggest_pipeline = SuggestFollowupQuesPipeline()
|
||||||
|
suggest_pipeline.lang = SUPPORTED_LANGUAGE_MAP.get(
|
||||||
|
settings["reasoning.lang"], "English"
|
||||||
|
)
|
||||||
|
|
||||||
|
updated = False
|
||||||
|
|
||||||
|
suggested_ques = []
|
||||||
|
if len(chat_history) >= 1:
|
||||||
|
suggested_resp = suggest_pipeline(chat_history).text
|
||||||
|
if ques_res := re.search(r"\[(.*?)\]", re.sub("\n", "", suggested_resp)):
|
||||||
|
ques_res_str = ques_res.group()
|
||||||
|
try:
|
||||||
|
suggested_ques = ast.literal_eval(ques_res_str)
|
||||||
|
suggested_ques = [[x] for x in suggested_ques]
|
||||||
|
updated = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return suggested_ques, updated
|
||||||
|
|
||||||
|
def backup_original_info(
|
||||||
|
self, chat_history, settings, info_pannel, original_chat_history
|
||||||
|
):
|
||||||
|
original_chat_history.append(chat_history[-1])
|
||||||
|
return original_chat_history, settings, info_pannel
|
||||||
|
|
||||||
|
def save_log(
|
||||||
|
self,
|
||||||
|
conversation_id,
|
||||||
|
chat_history,
|
||||||
|
settings,
|
||||||
|
info_panel,
|
||||||
|
original_chat_history,
|
||||||
|
original_settings,
|
||||||
|
original_info_panel,
|
||||||
|
log_dir,
|
||||||
|
):
|
||||||
|
if not Path(log_dir).exists():
|
||||||
|
Path(log_dir).mkdir(parents=True)
|
||||||
|
|
||||||
|
lock = FileLock(Path(log_dir) / ".lock")
|
||||||
|
# get current date
|
||||||
|
today = datetime.now()
|
||||||
|
formatted_date = today.strftime("%d%m%Y_%H")
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||||
|
result = session.exec(statement).one()
|
||||||
|
|
||||||
|
data_source = deepcopy(result.data_source)
|
||||||
|
likes = data_source.get("likes", [])
|
||||||
|
if not likes:
|
||||||
|
return
|
||||||
|
|
||||||
|
feedback = likes[-1][-1]
|
||||||
|
message_index = likes[-1][0]
|
||||||
|
|
||||||
|
current_message = chat_history[message_index[0]]
|
||||||
|
original_message = original_chat_history[message_index[0]]
|
||||||
|
is_original = all(
|
||||||
|
[
|
||||||
|
current_item == original_item
|
||||||
|
for current_item, original_item in zip(
|
||||||
|
current_message, original_message
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
dataframe = [
|
||||||
|
[
|
||||||
|
conversation_id,
|
||||||
|
message_index,
|
||||||
|
current_message[0],
|
||||||
|
current_message[1],
|
||||||
|
chat_history,
|
||||||
|
settings,
|
||||||
|
info_panel,
|
||||||
|
feedback,
|
||||||
|
is_original,
|
||||||
|
original_message[1],
|
||||||
|
original_chat_history,
|
||||||
|
original_settings,
|
||||||
|
original_info_panel,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
with lock:
|
||||||
|
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
|
||||||
|
is_log_file_exist = log_file.is_file()
|
||||||
|
with open(log_file, "a") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
# write headers
|
||||||
|
if not is_log_file_exist:
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
"Conversation ID",
|
||||||
|
"Message ID",
|
||||||
|
"Question",
|
||||||
|
"Answer",
|
||||||
|
"Chat History",
|
||||||
|
"Settings",
|
||||||
|
"Evidences",
|
||||||
|
"Feedback",
|
||||||
|
"Original/ Rewritten",
|
||||||
|
"Original Answer",
|
||||||
|
"Original Chat History",
|
||||||
|
"Original Settings",
|
||||||
|
"Original Evidences",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
writer.writerows(dataframe)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from ktem.app import BasePage
|
from ktem.app import BasePage
|
||||||
|
@ -9,6 +10,7 @@ from sqlmodel import Session, or_, select
|
||||||
import flowsettings
|
import flowsettings
|
||||||
|
|
||||||
from ...utils.conversation import sync_retrieval_n_message
|
from ...utils.conversation import sync_retrieval_n_message
|
||||||
|
from .chat_suggestion import ChatSuggestion
|
||||||
from .common import STATE
|
from .common import STATE
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -103,6 +105,10 @@ class ConversationControl(BasePage):
|
||||||
visible=False,
|
visible=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.followup_suggestions = gr.State([])
|
||||||
|
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||||
|
self.chat_suggestion = ChatSuggestion(self._app)
|
||||||
|
|
||||||
def load_chat_history(self, user_id):
|
def load_chat_history(self, user_id):
|
||||||
"""Reload chat history"""
|
"""Reload chat history"""
|
||||||
|
|
||||||
|
@ -220,6 +226,8 @@ class ConversationControl(BasePage):
|
||||||
|
|
||||||
chats = result.data_source.get("messages", [])
|
chats = result.data_source.get("messages", [])
|
||||||
|
|
||||||
|
chat_suggestions = result.data_source.get("chat_suggestions", [])
|
||||||
|
|
||||||
retrieval_history: list[str] = result.data_source.get(
|
retrieval_history: list[str] = result.data_source.get(
|
||||||
"retrieval_messages", []
|
"retrieval_messages", []
|
||||||
)
|
)
|
||||||
|
@ -243,6 +251,7 @@ class ConversationControl(BasePage):
|
||||||
name = ""
|
name = ""
|
||||||
selected = {}
|
selected = {}
|
||||||
chats = []
|
chats = []
|
||||||
|
chat_suggestions = []
|
||||||
retrieval_history = []
|
retrieval_history = []
|
||||||
plot_history = []
|
plot_history = []
|
||||||
info_panel = ""
|
info_panel = ""
|
||||||
|
@ -265,6 +274,7 @@ class ConversationControl(BasePage):
|
||||||
id_,
|
id_,
|
||||||
name,
|
name,
|
||||||
chats,
|
chats,
|
||||||
|
chat_suggestions,
|
||||||
info_panel,
|
info_panel,
|
||||||
plot_data,
|
plot_data,
|
||||||
retrieval_history,
|
retrieval_history,
|
||||||
|
@ -311,6 +321,46 @@ class ConversationControl(BasePage):
|
||||||
gr.update(visible=False),
|
gr.update(visible=False),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_chat_suggestions(
|
||||||
|
self, conversation_id, new_suggestions, is_updated, user_id
|
||||||
|
):
|
||||||
|
"""Update the conversation's chat suggestions"""
|
||||||
|
if not is_updated:
|
||||||
|
return (
|
||||||
|
gr.update(),
|
||||||
|
conversation_id,
|
||||||
|
gr.update(visible=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
gr.Warning("Please sign in first (Settings → User Settings)")
|
||||||
|
return gr.update(), ""
|
||||||
|
|
||||||
|
if not conversation_id:
|
||||||
|
gr.Warning("No conversation selected.")
|
||||||
|
return gr.update(), ""
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||||
|
result = session.exec(statement).one()
|
||||||
|
|
||||||
|
data_source = deepcopy(result.data_source)
|
||||||
|
data_source["chat_suggestions"] = [
|
||||||
|
[x] for x in new_suggestions.iloc[:, 0].tolist()
|
||||||
|
]
|
||||||
|
|
||||||
|
result.data_source = data_source
|
||||||
|
session.add(result)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
history = self.load_chat_history(user_id)
|
||||||
|
gr.Info("Chat suggestions updated.")
|
||||||
|
return (
|
||||||
|
gr.update(choices=history),
|
||||||
|
conversation_id,
|
||||||
|
gr.update(visible=False),
|
||||||
|
)
|
||||||
|
|
||||||
def _on_app_created(self):
|
def _on_app_created(self):
|
||||||
"""Reload the conversation once the app is created"""
|
"""Reload the conversation once the app is created"""
|
||||||
self._app.app.load(
|
self._app.app.load(
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from ktem.llms.manager import llms
|
||||||
|
|
||||||
|
from kotaemon.base import AIMessage, BaseComponent, Document, HumanMessage, Node
|
||||||
|
from kotaemon.llms import ChatLLM, PromptTemplate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestFollowupQuesPipeline(BaseComponent):
|
||||||
|
"""Suggest a list of follow-up questions based on the chat history."""
|
||||||
|
|
||||||
|
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
|
||||||
|
SUGGEST_QUESTIONS_PROMPT_TEMPLATE = (
|
||||||
|
"Based on the chat history above. "
|
||||||
|
"your task is to generate 3 to 5 relevant follow-up questions. "
|
||||||
|
"These questions should be simple, clear, "
|
||||||
|
"and designed to guide the conversation further. "
|
||||||
|
"Ensure that the questions are open-ended to encourage detailed responses. "
|
||||||
|
"Respond in JSON format with 'questions' key. "
|
||||||
|
"Answer using the language {lang} same as the question. "
|
||||||
|
"If the question uses Chinese, the answer should be in Chinese.\n"
|
||||||
|
)
|
||||||
|
prompt_template: str = SUGGEST_QUESTIONS_PROMPT_TEMPLATE
|
||||||
|
extra_prompt: str = """Example of valid response:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"questions": ["the weather is good", "what's your favorite city"]
|
||||||
|
}
|
||||||
|
```"""
|
||||||
|
lang: str = "English"
|
||||||
|
|
||||||
|
def run(self, chat_history: list[tuple[str, str]]) -> Document:
|
||||||
|
prompt_template = PromptTemplate(self.prompt_template)
|
||||||
|
prompt = prompt_template.populate(lang=self.lang) + self.extra_prompt
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for human, ai in chat_history[-3:]:
|
||||||
|
messages.append(HumanMessage(content=human))
|
||||||
|
messages.append(AIMessage(content=ai))
|
||||||
|
|
||||||
|
messages.append(HumanMessage(content=prompt))
|
||||||
|
|
||||||
|
return self.llm(messages)
|
Loading…
Reference in New Issue
Block a user