Feat/regenerate answer (#7)
* Add regen button and rephrase the question on regen
* Stop appending regen messages to history, allow only one
* Add dynamic conversation state
* Allow reasoning pipeline to manipulate state

Co-authored-by: albert <albert@cinnamon.is>
Co-authored-by: Duc Nguyen (john) <trungduc1992@gmail.com>
parent e67a25c0bd
commit 43a18ba070
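For orientation before the diffs: the feature threads one mutable dict through every Gradio event below. Its default shape comes from the new `libs/ktem/ktem/pages/chat/common.py` further down; the `"simple"` slice sketched here is a hypothetical example of a per-pipeline entry, not part of the commit:

```python
STATE = {
    "app": {"regen": False},  # app-level flags; "regen" marks a regenerate request
    # chat_fn later adds one slice per reasoning pipeline, keyed by its id, e.g.:
    # "simple": {...}  # hypothetical pipeline-owned state
}
```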
```diff
@@ -15,8 +15,6 @@ from decouple import config
 
 from kotaemon.loaders.utils.gpt4v import generate_gpt4v
 
-logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
-
 
 def request_adobe_service(file_path: str, output_path: str = "") -> str:
     """Main function to call the adobe service, and unzip the results.
```
```diff
@@ -17,6 +17,7 @@ class BaseApp:
 
     The main application contains app-level information:
         - setting state
+        - dynamic conversation state
         - user id
 
     Also contains registering methods for:
```
```diff
@@ -9,6 +9,7 @@ from ktem.db.models import Conversation, engine
 from sqlmodel import Session, select
 
 from .chat_panel import ChatPanel
+from .common import STATE
 from .control import ConversationControl
 from .report import ReportIssue
 
```
```diff
@@ -21,6 +22,7 @@ class ChatPage(BasePage):
 
     def on_building_ui(self):
         with gr.Row():
+            self.chat_state = gr.State(STATE)
             with gr.Column(scale=1):
                 self.chat_control = ConversationControl(self._app)
```
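`gr.State(STATE)` gives each browser session its own copy of the dict, which Gradio then passes into and out of event handlers. A minimal standalone sketch of that round trip (handler and component names are invented for illustration):

```python
import gradio as gr

STATE = {"app": {"regen": False}}

def toggle_regen(state):
    # mutate the session's copy and return it so Gradio stores the update
    state["app"]["regen"] = not state["app"]["regen"]
    return f"regen={state['app']['regen']}", state

with gr.Blocks() as demo:
    chat_state = gr.State(STATE)  # deep-copied once per session
    label = gr.Textbox(label="state")
    btn = gr.Button("Toggle regen")
    btn.click(toggle_regen, inputs=[chat_state], outputs=[label, chat_state])

demo.launch()
```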
```diff
@@ -62,12 +64,13 @@ class ChatPage(BasePage):
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self._app.settings_state,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=[
-                self.chat_panel.text_input,
                 self.chat_panel.chatbot,
                 self.info_panel,
+                self.chat_state,
             ],
             show_progress="minimal",
         ).then(
```
```diff
@@ -75,6 +78,33 @@ class ChatPage(BasePage):
             inputs=[
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
+                self.chat_state,
+            ]
+            + self._indices_input,
+            outputs=None,
+        )
+
+        self.chat_panel.regen_btn.click(
+            fn=self.regen_fn,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_panel.chatbot,
+                self._app.settings_state,
+                self.chat_state,
+            ]
+            + self._indices_input,
+            outputs=[
+                self.chat_panel.chatbot,
+                self.info_panel,
+                self.chat_state,
+            ],
+            show_progress="minimal",
+        ).then(
+            fn=self.update_data_source,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_panel.chatbot,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=None,
```
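The regen wiring mirrors the submit path: a `.click(...)` that streams into the chatbot, chained with `.then(...)` so persistence only runs after streaming finishes. A reduced sketch of the pattern (generator and save functions are invented):

```python
import time
import gradio as gr

def stream_answer(history):
    # a generator handler: each yield updates the chatbot incrementally
    history = history + [["question", ""]]  # pair-style history, as in this codebase
    for token in ["Thinking", " ...", " done"]:
        history[-1][1] += token
        time.sleep(0.2)
        yield history

def save(history):
    print("persisted", len(history), "turns")  # stands in for update_data_source

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    regen = gr.Button("Regen")
    regen.click(stream_answer, inputs=[chatbot], outputs=[chatbot]).then(
        save, inputs=[chatbot], outputs=None
    )

demo.launch()
```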
```diff
@@ -94,6 +124,7 @@ class ChatPage(BasePage):
                 self.chat_control.conversation,
                 self.chat_control.conversation_rn,
                 self.chat_panel.chatbot,
+                self.chat_state,
             ]
             + self._indices_input,
             show_progress="hidden",
```
```diff
@@ -109,12 +140,13 @@ class ChatPage(BasePage):
                 self.chat_panel.chatbot,
                 self._app.settings_state,
                 self._app.user_id,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=None,
         )
 
-    def update_data_source(self, convo_id, messages, *selecteds):
+    def update_data_source(self, convo_id, messages, state, *selecteds):
         """Update the data source"""
         if not convo_id:
             gr.Warning("No conversation selected")
```
```diff
@@ -133,6 +165,7 @@ class ChatPage(BasePage):
             result.data_source = {
                 "selected": selecteds_,
                 "messages": messages,
+                "state": state,
                 "likes": deepcopy(data_source.get("likes", [])),
             }
             session.add(result)
```
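After this change, a conversation row's `data_source` carries the dynamic state alongside the transcript. An illustrative value (the contents are invented; only the keys come from the diff):

```python
data_source = {
    "selected": {"1": ["file-a.pdf"]},               # per-index file selections
    "messages": [["hello", "Hi, how can I help?"]],  # chat transcript
    "state": {"app": {"regen": False}},              # dynamic conversation state
    "likes": [],                                     # feedback, preserved via deepcopy
}
```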
```diff
@@ -152,17 +185,22 @@ class ChatPage(BasePage):
             session.add(result)
             session.commit()
 
-    def create_pipeline(self, settings: dict, *selecteds):
+    def create_pipeline(self, settings: dict, state: dict, *selecteds):
         """Create the pipeline from settings
 
         Args:
             settings: the settings of the app
+            is_regen: whether the regen button is clicked
             selected: the list of file ids that will be served as context. If None, then
                 consider using all files
 
         Returns:
-            the pipeline objects
+            - the pipeline objects
         """
+        reasoning_mode = settings["reasoning.use"]
+        reasoning_cls = reasonings[reasoning_mode]
+        reasoning_id = reasoning_cls.get_info()["id"]
+
         # get retrievers
         retrievers = []
         for index in self._app.index_manager.indices:
```
```diff
@@ -172,13 +210,17 @@ class ChatPage(BasePage):
             iretrievers = index.get_retriever_pipelines(settings, index_selected)
             retrievers += iretrievers
 
-        reasoning_mode = settings["reasoning.use"]
-        reasoning_cls = reasonings[reasoning_mode]
-        pipeline = reasoning_cls.get_pipeline(settings, retrievers)
+        # prepare states
+        reasoning_state = {
+            "app": deepcopy(state["app"]),
+            "pipeline": deepcopy(state.get(reasoning_id, {})),
+        }
+
+        pipeline = reasoning_cls.get_pipeline(settings, reasoning_state, retrievers)
 
-        return pipeline
+        return pipeline, reasoning_state
 
-    async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
+    async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
         """Chat function"""
         chat_input = chat_history[-1][0]
         chat_history = chat_history[:-1]
```
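`create_pipeline` now hands each reasoning class a deep-copied view of the conversation state, with the pipeline's own slice looked up by its id, so a pipeline can mutate its slice without clobbering other pipelines' entries. A sketch of the round trip (ids and values invented):

```python
from copy import deepcopy

state = {"app": {"regen": False}, "simple": {"n_runs": 3}}
reasoning_id = "simple"  # would come from reasoning_cls.get_info()["id"]

reasoning_state = {
    "app": deepcopy(state["app"]),
    "pipeline": deepcopy(state.get(reasoning_id, {})),
}

# the pipeline mutates only its own slice while it runs ...
reasoning_state["pipeline"]["n_runs"] += 1

# ... and chat_fn writes that slice back under the same id on every yield
state[reasoning_id] = reasoning_state["pipeline"]
print(state)  # {'app': {'regen': False}, 'simple': {'n_runs': 4}}
```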
```diff
@@ -186,7 +228,7 @@ class ChatPage(BasePage):
         queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
 
         # construct the pipeline
-        pipeline = self.create_pipeline(settings, *selecteds)
+        pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
         pipeline.set_output_queue(queue)
 
         asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
```
```diff
@@ -198,7 +240,8 @@ class ChatPage(BasePage):
             try:
                 response = queue.get_nowait()
             except Exception:
-                yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
+                state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+                yield chat_history + [(chat_input, text or "Thinking ...")], refs, state
                 continue
 
             if response is None:
```
```diff
@@ -208,6 +251,7 @@ class ChatPage(BasePage):
 
             if "output" in response:
                 text += response["output"]
+
             if "evidence" in response:
                 if response["evidence"] is None:
                     refs = ""
```
```diff
@@ -218,4 +262,25 @@ class ChatPage(BasePage):
                 print(f"Len refs: {len(refs)}")
                 len_ref = len(refs)
 
-            yield "", chat_history + [(chat_input, text)], refs
+            state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+            yield chat_history + [(chat_input, text)], refs, state
+
+    async def regen_fn(
+        self, conversation_id, chat_history, settings, state, *selecteds
+    ):
+        """Regen function"""
+        if not chat_history:
+            gr.Warning("Empty chat")
+            yield chat_history, "", state
+            return
+
+        state["app"]["regen"] = True
+        async for chat, refs, state in self.chat_fn(
+            conversation_id, chat_history, settings, state, *selecteds
+        ):
+            new_state = deepcopy(state)
+            new_state["app"]["regen"] = False
+            yield chat, refs, new_state
+        else:
+            state["app"]["regen"] = False
+            yield chat_history, "", state
```
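`regen_fn` leans on Python's loop `else`: on a `for`/`async for`, the `else` block runs once the loop finishes without `break`, including when the iterable is empty, so the final reset of the regen flag always happens. A minimal illustration:

```python
import asyncio

async def stream(n):
    for i in range(n):
        yield i

async def main():
    async for item in stream(0):  # nothing to yield
        print("got", item)
    else:
        print("stream finished; reset flags here")  # still runs

asyncio.run(main())
```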
```diff
@@ -19,6 +19,7 @@ class ChatPanel(BasePage):
                 placeholder="Chat input", scale=15, container=False
             )
             self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
+            self.regen_btn = gr.Button(value="Regen", scale=1, min_width=10)
 
     def submit_msg(self, chat_input, chat_history):
         """Submit a message to the chatbot"""
```
libs/ktem/ktem/pages/chat/common.py (new file)

```diff
@@ -0,0 +1,4 @@
+DEFAULT_APPLICATION_STATE = {"regen": False}
+STATE = {
+    "app": DEFAULT_APPLICATION_STATE,
+}
```
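One caveat worth noting: `STATE` is a module-level mutable dict. Gradio deep-copies a `gr.State` default per session, but plain fallbacks such as `data_source.get("state", STATE)` (in the control.py hunk below) hand back the shared object itself. A defensive variant, not part of this commit, would copy on load:

```python
from copy import deepcopy

DEFAULT_APPLICATION_STATE = {"regen": False}
STATE = {"app": DEFAULT_APPLICATION_STATE}

def load_state(data_source: dict) -> dict:
    # fall back to a fresh copy so callers never mutate the shared default
    return deepcopy(data_source.get("state", STATE))
```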
```diff
@@ -5,6 +5,8 @@ from ktem.app import BasePage
 from ktem.db.models import Conversation, engine
 from sqlmodel import Session, select
 
+from .common import STATE
+
 logger = logging.getLogger(__name__)
 
 
```
```diff
@@ -159,12 +161,14 @@ class ConversationControl(BasePage):
             name = result.name
             selected = result.data_source.get("selected", {})
             chats = result.data_source.get("messages", [])
+            state = result.data_source.get("state", STATE)
         except Exception as e:
             logger.warning(e)
             id_ = ""
             name = ""
             selected = {}
             chats = []
+            state = STATE
 
         indices = []
         for index in self._app.index_manager.indices:
```
```diff
@@ -173,7 +177,7 @@ class ConversationControl(BasePage):
                 continue
             indices.append(selected.get(str(index.id), []))
 
-        return id_, id_, name, chats, *indices
+        return id_, id_, name, chats, state, *indices
 
     def rename_conv(self, conversation_id, new_name, user_id):
         """Rename the conversation"""
```
```diff
@@ -48,6 +48,7 @@ class ReportIssue(BasePage):
         chat_history: list,
         settings: dict,
         user_id: Optional[int],
+        chat_state: dict,
         *selecteds
     ):
         selecteds_ = {}
```
```diff
@@ -65,6 +66,7 @@ class ReportIssue(BasePage):
             chat={
                 "conv_id": conv_id,
                 "chat_history": chat_history,
+                "chat_state": chat_state,
                 "selecteds": selecteds_,
             },
             settings=settings,
```
```diff
@@ -7,7 +7,6 @@ from functools import partial
 
 import tiktoken
 from ktem.components import llms
-from ktem.reasoning.base import BaseReasoning
 from theflow.settings import settings as flowsettings
 
 from kotaemon.base import (
```
```diff
@@ -164,6 +163,15 @@ DEFAULT_QA_FIGURE_PROMPT = (
     "Answer: "
 )
 
+DEFAULT_REWRITE_PROMPT = (
+    "Given the following question, rephrase and expand it "
+    "to help you do better answering. Maintain all information "
+    "in the original question. Keep the question as concise as possible. "
+    "Give answer in {lang}\n"
+    "Original question: {question}\n"
+    "Rephrased question: "
+)
+
 
 class AnswerWithContextPipeline(BaseComponent):
     """Answer the question based on the evidence
```
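The rewrite prompt exposes two placeholders, `{lang}` and `{question}`, which the pipeline below fills through kotaemon's `PromptTemplate.populate`. Plain `str.format` produces the same text, for illustration:

```python
DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "
)

print(DEFAULT_REWRITE_PROMPT.format(lang="English", question="What is RAG?"))
```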
```diff
@@ -287,15 +295,48 @@ class AnswerWithContextPipeline(BaseComponent):
 
         return answer
 
+
     def extract_evidence_images(self, evidence: str):
         """Util function to extract and isolate images from context/evidence"""
         image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
         matches = re.findall(image_pattern, evidence)
         context = re.sub(image_pattern, "", evidence)
         return context, matches
 
 
-class FullQAPipeline(BaseReasoning):
+class RewriteQuestionPipeline(BaseComponent):
+    """Rewrite user question
+
+    Args:
+        llm: the language model to rewrite question
+        rewrite_template: the prompt template for llm to paraphrase a text input
+        lang: the language of the answer. Currently support English and Japanese
+    """
+
+    llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
+    rewrite_template: str = DEFAULT_REWRITE_PROMPT
+
+    lang: str = "English"
+
+    async def run(self, question: str) -> Document:  # type: ignore
+        prompt_template = PromptTemplate(self.rewrite_template)
+        prompt = prompt_template.populate(question=question, lang=self.lang)
+        messages = [
+            SystemMessage(content="You are a helpful assistant"),
+            HumanMessage(content=prompt),
+        ]
+        output = ""
+        for text in self.llm(messages):
+            if "content" in text:
+                output += text[1]
+                self.report_output({"chat_input": text[1]})
+                break
+            await asyncio.sleep(0)
+
+        return Document(text=output)
+
+
+class FullQAPipeline(BaseComponent):
     """Question answering pipeline. Handle from question to answer"""
 
     class Config:
```
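`FullQAPipeline.run` consumes the rewriter with `rewrite = await self.rewrite_pipeline(question=message)` and swaps the user message for `rewrite.text` before retrieval (next hunk). A standalone stub mirroring that call shape (the stub stands in for the real LLM-backed pipeline):

```python
import asyncio
from dataclasses import dataclass

@dataclass
class Doc:
    text: str

class StubRewriter:
    # stands in for RewriteQuestionPipeline: question in, Document-like out
    async def __call__(self, question: str) -> Doc:
        return Doc(text=f"(rephrased) {question}")

async def main():
    message = "what is rag?"
    rewrite = await StubRewriter()(question=message)
    message = rewrite.text  # same pattern as FullQAPipeline.run
    print(message)

asyncio.run(main())
```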
```diff
@@ -305,12 +346,18 @@ class FullQAPipeline(BaseReasoning):
 
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
+    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
+    use_rewrite: bool = False
 
     async def run(  # type: ignore
         self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
     ) -> Document:  # type: ignore
         docs = []
         doc_ids = []
+        if self.use_rewrite:
+            rewrite = await self.rewrite_pipeline(question=message)
+            message = rewrite.text
+
         for retriever in self.retrievers:
             for doc in retriever(text=message):
                 if doc.doc_id not in doc_ids:
```
```diff
@@ -402,7 +449,7 @@ class FullQAPipeline(BaseReasoning):
         return answer
 
     @classmethod
-    def get_pipeline(cls, settings, retrievers):
+    def get_pipeline(cls, settings, states, retrievers):
         """Get the reasoning pipeline
 
         Args:
```
```diff
@@ -430,6 +477,11 @@ class FullQAPipeline(BaseReasoning):
         pipeline.answering_pipeline.qa_template = settings[
             f"reasoning.options.{_id}.qa_prompt"
         ]
+        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
+        pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
+        pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
+            settings["reasoning.lang"], "English"
+        )
         return pipeline
 
     @classmethod
```
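`get_pipeline` is where the app-level regen flag becomes behavior: `use_rewrite` is read from `states["app"]["regen"]`, so only regenerated questions pass through the rewriter, and the UI language code maps onto the prompt's `{lang}` with an English fallback. The two lookups in isolation:

```python
states = {"app": {"regen": True}}  # what regen_fn sets before calling chat_fn
use_rewrite = states.get("app", {}).get("regen", False)

lang = {"en": "English", "ja": "Japanese"}.get("vi", "English")  # unknown code
print(use_rewrite, lang)  # True English
```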