Feat/regenerate answer (#7)
* Add regen button and repharasing question on regen * Stop appending regen messages to history, allow only one * Add dynamic conversation state * Allow reasoning pipeline to manipulate state --------- Co-authored-by: albert <albert@cinnamon.is> Co-authored-by: Duc Nguyen (john) <trungduc1992@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from ktem.db.models import Conversation, engine
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .chat_panel import ChatPanel
|
||||
from .common import STATE
|
||||
from .control import ConversationControl
|
||||
from .report import ReportIssue
|
||||
|
||||
@@ -21,6 +22,7 @@ class ChatPage(BasePage):
|
||||
|
||||
def on_building_ui(self):
|
||||
with gr.Row():
|
||||
self.chat_state = gr.State(STATE)
|
||||
with gr.Column(scale=1):
|
||||
self.chat_control = ConversationControl(self._app)
|
||||
|
||||
@@ -62,12 +64,13 @@ class ChatPage(BasePage):
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self._app.settings_state,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
outputs=[
|
||||
self.chat_panel.text_input,
|
||||
self.chat_panel.chatbot,
|
||||
self.info_panel,
|
||||
self.chat_state,
|
||||
],
|
||||
show_progress="minimal",
|
||||
).then(
|
||||
@@ -75,6 +78,33 @@ class ChatPage(BasePage):
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
self.chat_panel.regen_btn.click(
|
||||
fn=self.regen_fn,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self._app.settings_state,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
outputs=[
|
||||
self.chat_panel.chatbot,
|
||||
self.info_panel,
|
||||
self.chat_state,
|
||||
],
|
||||
show_progress="minimal",
|
||||
).then(
|
||||
fn=self.update_data_source,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
outputs=None,
|
||||
@@ -94,6 +124,7 @@ class ChatPage(BasePage):
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.chat_panel.chatbot,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
show_progress="hidden",
|
||||
@@ -109,12 +140,13 @@ class ChatPage(BasePage):
|
||||
self.chat_panel.chatbot,
|
||||
self._app.settings_state,
|
||||
self._app.user_id,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
outputs=None,
|
||||
)
|
||||
|
||||
def update_data_source(self, convo_id, messages, *selecteds):
|
||||
def update_data_source(self, convo_id, messages, state, *selecteds):
|
||||
"""Update the data source"""
|
||||
if not convo_id:
|
||||
gr.Warning("No conversation selected")
|
||||
@@ -133,6 +165,7 @@ class ChatPage(BasePage):
|
||||
result.data_source = {
|
||||
"selected": selecteds_,
|
||||
"messages": messages,
|
||||
"state": state,
|
||||
"likes": deepcopy(data_source.get("likes", [])),
|
||||
}
|
||||
session.add(result)
|
||||
@@ -152,17 +185,22 @@ class ChatPage(BasePage):
|
||||
session.add(result)
|
||||
session.commit()
|
||||
|
||||
def create_pipeline(self, settings: dict, *selecteds):
|
||||
def create_pipeline(self, settings: dict, state: dict, *selecteds):
|
||||
"""Create the pipeline from settings
|
||||
|
||||
Args:
|
||||
settings: the settings of the app
|
||||
is_regen: whether the regen button is clicked
|
||||
selected: the list of file ids that will be served as context. If None, then
|
||||
consider using all files
|
||||
|
||||
Returns:
|
||||
the pipeline objects
|
||||
- the pipeline objects
|
||||
"""
|
||||
reasoning_mode = settings["reasoning.use"]
|
||||
reasoning_cls = reasonings[reasoning_mode]
|
||||
reasoning_id = reasoning_cls.get_info()["id"]
|
||||
|
||||
# get retrievers
|
||||
retrievers = []
|
||||
for index in self._app.index_manager.indices:
|
||||
@@ -172,13 +210,17 @@ class ChatPage(BasePage):
|
||||
iretrievers = index.get_retriever_pipelines(settings, index_selected)
|
||||
retrievers += iretrievers
|
||||
|
||||
reasoning_mode = settings["reasoning.use"]
|
||||
reasoning_cls = reasonings[reasoning_mode]
|
||||
pipeline = reasoning_cls.get_pipeline(settings, retrievers)
|
||||
# prepare states
|
||||
reasoning_state = {
|
||||
"app": deepcopy(state["app"]),
|
||||
"pipeline": deepcopy(state.get(reasoning_id, {})),
|
||||
}
|
||||
|
||||
return pipeline
|
||||
pipeline = reasoning_cls.get_pipeline(settings, reasoning_state, retrievers)
|
||||
|
||||
async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
|
||||
return pipeline, reasoning_state
|
||||
|
||||
async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
|
||||
"""Chat function"""
|
||||
chat_input = chat_history[-1][0]
|
||||
chat_history = chat_history[:-1]
|
||||
@@ -186,7 +228,7 @@ class ChatPage(BasePage):
|
||||
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
|
||||
|
||||
# construct the pipeline
|
||||
pipeline = self.create_pipeline(settings, *selecteds)
|
||||
pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
|
||||
pipeline.set_output_queue(queue)
|
||||
|
||||
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
|
||||
@@ -198,7 +240,8 @@ class ChatPage(BasePage):
|
||||
try:
|
||||
response = queue.get_nowait()
|
||||
except Exception:
|
||||
yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
|
||||
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
|
||||
yield chat_history + [(chat_input, text or "Thinking ...")], refs, state
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
@@ -208,6 +251,7 @@ class ChatPage(BasePage):
|
||||
|
||||
if "output" in response:
|
||||
text += response["output"]
|
||||
|
||||
if "evidence" in response:
|
||||
if response["evidence"] is None:
|
||||
refs = ""
|
||||
@@ -218,4 +262,25 @@ class ChatPage(BasePage):
|
||||
print(f"Len refs: {len(refs)}")
|
||||
len_ref = len(refs)
|
||||
|
||||
yield "", chat_history + [(chat_input, text)], refs
|
||||
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
|
||||
yield chat_history + [(chat_input, text)], refs, state
|
||||
|
||||
async def regen_fn(
|
||||
self, conversation_id, chat_history, settings, state, *selecteds
|
||||
):
|
||||
"""Regen function"""
|
||||
if not chat_history:
|
||||
gr.Warning("Empty chat")
|
||||
yield chat_history, "", state
|
||||
return
|
||||
|
||||
state["app"]["regen"] = True
|
||||
async for chat, refs, state in self.chat_fn(
|
||||
conversation_id, chat_history, settings, state, *selecteds
|
||||
):
|
||||
new_state = deepcopy(state)
|
||||
new_state["app"]["regen"] = False
|
||||
yield chat, refs, new_state
|
||||
else:
|
||||
state["app"]["regen"] = False
|
||||
yield chat_history, "", state
|
||||
|
@@ -19,6 +19,7 @@ class ChatPanel(BasePage):
|
||||
placeholder="Chat input", scale=15, container=False
|
||||
)
|
||||
self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
|
||||
self.regen_btn = gr.Button(value="Regen", scale=1, min_width=10)
|
||||
|
||||
def submit_msg(self, chat_input, chat_history):
|
||||
"""Submit a message to the chatbot"""
|
||||
|
4
libs/ktem/ktem/pages/chat/common.py
Normal file
4
libs/ktem/ktem/pages/chat/common.py
Normal file
@@ -0,0 +1,4 @@
|
||||
DEFAULT_APPLICATION_STATE = {"regen": False}
|
||||
STATE = {
|
||||
"app": DEFAULT_APPLICATION_STATE,
|
||||
}
|
@@ -5,6 +5,8 @@ from ktem.app import BasePage
|
||||
from ktem.db.models import Conversation, engine
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from .common import STATE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -159,12 +161,14 @@ class ConversationControl(BasePage):
|
||||
name = result.name
|
||||
selected = result.data_source.get("selected", {})
|
||||
chats = result.data_source.get("messages", [])
|
||||
state = result.data_source.get("state", STATE)
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
id_ = ""
|
||||
name = ""
|
||||
selected = {}
|
||||
chats = []
|
||||
state = STATE
|
||||
|
||||
indices = []
|
||||
for index in self._app.index_manager.indices:
|
||||
@@ -173,7 +177,7 @@ class ConversationControl(BasePage):
|
||||
continue
|
||||
indices.append(selected.get(str(index.id), []))
|
||||
|
||||
return id_, id_, name, chats, *indices
|
||||
return id_, id_, name, chats, state, *indices
|
||||
|
||||
def rename_conv(self, conversation_id, new_name, user_id):
|
||||
"""Rename the conversation"""
|
||||
|
@@ -48,6 +48,7 @@ class ReportIssue(BasePage):
|
||||
chat_history: list,
|
||||
settings: dict,
|
||||
user_id: Optional[int],
|
||||
chat_state: dict,
|
||||
*selecteds
|
||||
):
|
||||
selecteds_ = {}
|
||||
@@ -65,6 +66,7 @@ class ReportIssue(BasePage):
|
||||
chat={
|
||||
"conv_id": conv_id,
|
||||
"chat_history": chat_history,
|
||||
"chat_state": chat_state,
|
||||
"selecteds": selecteds_,
|
||||
},
|
||||
settings=settings,
|
||||
|
Reference in New Issue
Block a user