Feat/regenerate answer (#7)

* Add regen button and rephrase the question on regen

* Stop appending regen messages to history, allow only one

* Add dynamic conversation state

* Allow reasoning pipeline to manipulate state

---------

Co-authored-by: albert <albert@cinnamon.is>
Co-authored-by: Duc Nguyen (john) <trungduc1992@gmail.com>
Authored by ian_Cin on 2024-04-03 15:37:55 +07:00, committed by GitHub
parent e67a25c0bd
commit 43a18ba070
8 changed files with 151 additions and 24 deletions
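
Taken together, the changes thread a per-conversation state dict through the chat UI and the reasoning pipeline, and add a Regen button that rephrases the last question before answering it again. A minimal sketch of the state shape involved; only the "app"/"regen" keys appear verbatim in the diff, and "simple" stands in for a reasoning-pipeline id:

state = {
    "app": {"regen": False},  # flipped to True while a Regen run is in flight
    "simple": {},             # hypothetical slot: each reasoning pipeline is keyed by its id
}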

View File

@@ -15,8 +15,6 @@ from decouple import config
 from kotaemon.loaders.utils.gpt4v import generate_gpt4v

-logging.basicConfig(level=os.environ.get("LOGLEVEL", "INFO"))
-
 def request_adobe_service(file_path: str, output_path: str = "") -> str:
     """Main function to call the adobe service, and unzip the results.

View File

@@ -17,6 +17,7 @@ class BaseApp:
     The main application contains app-level information:
         - setting state
+        - dynamic conversation state
         - user id

     Also contains registering methods for:

View File

@@ -9,6 +9,7 @@ from ktem.db.models import Conversation, engine
 from sqlmodel import Session, select

 from .chat_panel import ChatPanel
+from .common import STATE
 from .control import ConversationControl
 from .report import ReportIssue
@@ -21,6 +22,7 @@ class ChatPage(BasePage):
     def on_building_ui(self):
         with gr.Row():
+            self.chat_state = gr.State(STATE)
             with gr.Column(scale=1):
                 self.chat_control = ConversationControl(self._app)
@@ -62,12 +64,13 @@ class ChatPage(BasePage):
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self._app.settings_state,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=[
-                self.chat_panel.text_input,
                 self.chat_panel.chatbot,
                 self.info_panel,
+                self.chat_state,
             ],
             show_progress="minimal",
         ).then(
@@ -75,6 +78,33 @@ class ChatPage(BasePage):
             inputs=[
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
+                self.chat_state,
+            ]
+            + self._indices_input,
+            outputs=None,
+        )
+
+        self.chat_panel.regen_btn.click(
+            fn=self.regen_fn,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_panel.chatbot,
+                self._app.settings_state,
+                self.chat_state,
+            ]
+            + self._indices_input,
+            outputs=[
+                self.chat_panel.chatbot,
+                self.info_panel,
+                self.chat_state,
+            ],
+            show_progress="minimal",
+        ).then(
+            fn=self.update_data_source,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_panel.chatbot,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=None,
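
The Regen wiring mirrors the Send flow: a streaming event writes to the chatbot, the info panel, and the state, then a chained .then persists the conversation. A self-contained Gradio sketch of that click/.then pattern, with made-up handlers standing in for regen_fn and update_data_source:

import gradio as gr

STATE = {"app": {"regen": False}}

def stream_answer(history, state):
    # placeholder for regen_fn: a generator that streams chatbot updates
    yield history + [("(regen)", "rephrased answer")], state

def persist(history, state):
    # placeholder for update_data_source: runs once streaming finishes
    print(f"saving {len(history)} messages, regen={state['app']['regen']}")

with gr.Blocks() as demo:
    chat_state = gr.State(STATE)  # Gradio hands each session its own copy
    chatbot = gr.Chatbot()
    regen_btn = gr.Button("Regen")
    regen_btn.click(
        fn=stream_answer,
        inputs=[chatbot, chat_state],
        outputs=[chatbot, chat_state],
    ).then(
        fn=persist,
        inputs=[chatbot, chat_state],
        outputs=None,
    )

demo.queue()  # generator handlers require the queue
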
@@ -94,6 +124,7 @@ class ChatPage(BasePage):
                 self.chat_control.conversation,
                 self.chat_control.conversation_rn,
                 self.chat_panel.chatbot,
+                self.chat_state,
             ]
             + self._indices_input,
             show_progress="hidden",
@@ -109,12 +140,13 @@ class ChatPage(BasePage):
                 self.chat_panel.chatbot,
                 self._app.settings_state,
                 self._app.user_id,
+                self.chat_state,
             ]
             + self._indices_input,
             outputs=None,
         )

-    def update_data_source(self, convo_id, messages, *selecteds):
+    def update_data_source(self, convo_id, messages, state, *selecteds):
         """Update the data source"""
         if not convo_id:
             gr.Warning("No conversation selected")
@@ -133,6 +165,7 @@ class ChatPage(BasePage):
             result.data_source = {
                 "selected": selecteds_,
                 "messages": messages,
+                "state": state,
                 "likes": deepcopy(data_source.get("likes", [])),
             }
             session.add(result)
@@ -152,17 +185,22 @@ class ChatPage(BasePage):
             session.add(result)
             session.commit()

-    def create_pipeline(self, settings: dict, *selecteds):
+    def create_pipeline(self, settings: dict, state: dict, *selecteds):
         """Create the pipeline from settings

         Args:
             settings: the settings of the app
+            is_regen: whether the regen button is clicked
             selected: the list of file ids that will be served as context. If None, then
                 consider using all files

         Returns:
-            the pipeline objects
+            - the pipeline objects
         """
+        reasoning_mode = settings["reasoning.use"]
+        reasoning_cls = reasonings[reasoning_mode]
+        reasoning_id = reasoning_cls.get_info()["id"]
+
         # get retrievers
         retrievers = []
         for index in self._app.index_manager.indices:
@@ -172,13 +210,17 @@ class ChatPage(BasePage):
             iretrievers = index.get_retriever_pipelines(settings, index_selected)
             retrievers += iretrievers

-        reasoning_mode = settings["reasoning.use"]
-        reasoning_cls = reasonings[reasoning_mode]
-        pipeline = reasoning_cls.get_pipeline(settings, retrievers)
+        # prepare states
+        reasoning_state = {
+            "app": deepcopy(state["app"]),
+            "pipeline": deepcopy(state.get(reasoning_id, {})),
+        }
+
+        pipeline = reasoning_cls.get_pipeline(settings, reasoning_state, retrievers)

-        return pipeline
+        return pipeline, reasoning_state

-    async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
+    async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
         """Chat function"""
         chat_input = chat_history[-1][0]
         chat_history = chat_history[:-1]
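
create_pipeline now snapshots two slices of the conversation state for one reasoning run: the shared "app" flags and the pipeline's own namespace. Roughly, for a hypothetical reasoning id "simple":

from copy import deepcopy

state = {"app": {"regen": True}}  # the live gr.State value
reasoning_id = "simple"           # hypothetical, from reasoning_cls.get_info()["id"]

reasoning_state = {
    "app": deepcopy(state["app"]),                      # shared flags such as regen
    "pipeline": deepcopy(state.get(reasoning_id, {})),  # this pipeline's private slot
}
# The deep copies keep the pipeline from mutating the live state directly;
# chat_fn writes the slot back with state[reasoning_id] = reasoning_state["pipeline"].
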
@@ -186,7 +228,7 @@ class ChatPage(BasePage):
         queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()

         # construct the pipeline
-        pipeline = self.create_pipeline(settings, *selecteds)
+        pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
         pipeline.set_output_queue(queue)
         asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
@@ -198,7 +240,8 @@ class ChatPage(BasePage):
             try:
                 response = queue.get_nowait()
             except Exception:
-                yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
+                state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+                yield chat_history + [(chat_input, text or "Thinking ...")], refs, state
                 continue

             if response is None:
@@ -208,6 +251,7 @@ class ChatPage(BasePage):
             if "output" in response:
                 text += response["output"]
+
             if "evidence" in response:
                 if response["evidence"] is None:
                     refs = ""
@@ -218,4 +262,25 @@ class ChatPage(BasePage):
                 print(f"Len refs: {len(refs)}")
                 len_ref = len(refs)

-        yield "", chat_history + [(chat_input, text)], refs
+        state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
+        yield chat_history + [(chat_input, text)], refs, state
+
+    async def regen_fn(
+        self, conversation_id, chat_history, settings, state, *selecteds
+    ):
+        """Regen function"""
+        if not chat_history:
+            gr.Warning("Empty chat")
+            yield chat_history, "", state
+            return
+
+        state["app"]["regen"] = True
+        async for chat, refs, state in self.chat_fn(
+            conversation_id, chat_history, settings, state, *selecteds
+        ):
+            new_state = deepcopy(state)
+            new_state["app"]["regen"] = False
+            yield chat, refs, new_state
+        else:
+            state["app"]["regen"] = False
+            yield chat_history, "", state
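
Note the else on the async for: Python runs a loop's else block only when the loop finishes without break, so here it always fires after the stream is exhausted and resets the regen flag. A tiny illustration of that semantics:

for item in (1, 2, 3):
    pass
else:
    print("ran: the loop completed without break")

for item in (1, 2, 3):
    break
else:
    print("skipped: the loop was broken out of")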

View File

@@ -19,6 +19,7 @@ class ChatPanel(BasePage):
                 placeholder="Chat input", scale=15, container=False
             )
             self.submit_btn = gr.Button(value="Send", scale=1, min_width=10)
+            self.regen_btn = gr.Button(value="Regen", scale=1, min_width=10)

     def submit_msg(self, chat_input, chat_history):
         """Submit a message to the chatbot"""

View File

@@ -0,0 +1,4 @@
+DEFAULT_APPLICATION_STATE = {"regen": False}
+STATE = {
+    "app": DEFAULT_APPLICATION_STATE,
+}
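
Note that STATE holds a reference to DEFAULT_APPLICATION_STATE rather than a copy, so consumers must copy before mutating, which is what gr.State's per-session copy and the deepcopy calls in create_pipeline and regen_fn do. A small sketch of the aliasing hazard being avoided:

from copy import deepcopy

DEFAULT_APPLICATION_STATE = {"regen": False}
STATE = {"app": DEFAULT_APPLICATION_STATE}

# STATE["app"] aliases DEFAULT_APPLICATION_STATE, so copy before mutating:
session_state = deepcopy(STATE)
session_state["app"]["regen"] = True
print(DEFAULT_APPLICATION_STATE["regen"])  # False: the shared default is untouched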

View File

@@ -5,6 +5,8 @@ from ktem.app import BasePage
 from ktem.db.models import Conversation, engine
 from sqlmodel import Session, select

+from .common import STATE
+
 logger = logging.getLogger(__name__)
@@ -159,12 +161,14 @@ class ConversationControl(BasePage):
                 name = result.name
                 selected = result.data_source.get("selected", {})
                 chats = result.data_source.get("messages", [])
+                state = result.data_source.get("state", STATE)
         except Exception as e:
             logger.warning(e)
             id_ = ""
             name = ""
             selected = {}
             chats = []
+            state = STATE

         indices = []
         for index in self._app.index_manager.indices:
@@ -173,7 +177,7 @@ class ConversationControl(BasePage):
                 continue
             indices.append(selected.get(str(index.id), []))

-        return id_, id_, name, chats, *indices
+        return id_, id_, name, chats, state, *indices

     def rename_conv(self, conversation_id, new_name, user_id):
         """Rename the conversation"""

View File

@@ -48,6 +48,7 @@ class ReportIssue(BasePage):
         chat_history: list,
         settings: dict,
         user_id: Optional[int],
+        chat_state: dict,
         *selecteds
     ):
         selecteds_ = {}
@@ -65,6 +66,7 @@ class ReportIssue(BasePage):
                 chat={
                     "conv_id": conv_id,
                     "chat_history": chat_history,
+                    "chat_state": chat_state,
                     "selecteds": selecteds_,
                 },
                 settings=settings,

View File

@@ -7,7 +7,6 @@ from functools import partial
 import tiktoken
 from ktem.components import llms
-from ktem.reasoning.base import BaseReasoning
 from theflow.settings import settings as flowsettings

 from kotaemon.base import (
@@ -164,6 +163,15 @@ DEFAULT_QA_FIGURE_PROMPT = (
     "Answer: "
 )

+DEFAULT_REWRITE_PROMPT = (
+    "Given the following question, rephrase and expand it "
+    "to help you do better answering. Maintain all information "
+    "in the original question. Keep the question as concise as possible. "
+    "Give answer in {lang}\n"
+    "Original question: {question}\n"
+    "Rephrased question: "
+)
+

 class AnswerWithContextPipeline(BaseComponent):
     """Answer the question based on the evidence
@@ -287,6 +295,7 @@ class AnswerWithContextPipeline(BaseComponent):
         return answer

+
     def extract_evidence_images(self, evidence: str):
         """Util function to extract and isolate images from context/evidence"""
         image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
@@ -295,7 +304,39 @@ class AnswerWithContextPipeline(BaseComponent):
         return context, matches


-class FullQAPipeline(BaseReasoning):
+class RewriteQuestionPipeline(BaseComponent):
+    """Rewrite user question
+
+    Args:
+        llm: the language model to rewrite question
+        rewrite_template: the prompt template for llm to paraphrase a text input
+        lang: the language of the answer. Currently support English and Japanese
+    """
+
+    llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
+    rewrite_template: str = DEFAULT_REWRITE_PROMPT
+    lang: str = "English"
+
+    async def run(self, question: str) -> Document:  # type: ignore
+        prompt_template = PromptTemplate(self.rewrite_template)
+        prompt = prompt_template.populate(question=question, lang=self.lang)
+        messages = [
+            SystemMessage(content="You are a helpful assistant"),
+            HumanMessage(content=prompt),
+        ]
+        output = ""
+        for text in self.llm(messages):
+            if "content" in text:
+                output += text[1]
+                self.report_output({"chat_input": text[1]})
+                break
+            await asyncio.sleep(0)
+
+        return Document(text=output)
+
+
+class FullQAPipeline(BaseComponent):
     """Question answering pipeline. Handle from question to answer"""

     class Config:
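
Standalone, the new pipeline can be exercised like the sketch below; awaiting the component instance is exactly how FullQAPipeline.run invokes it, while the constructor keyword is an assumption about theflow's component parameters:

import asyncio

async def main():
    rewriter = RewriteQuestionPipeline(lang="English")  # llm defaults to llms.get_lowest_cost()
    doc = await rewriter(question="what does the regen button do?")
    print(doc.text)  # the rephrased, expanded question

asyncio.run(main())
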
@@ -305,12 +346,18 @@ class FullQAPipeline(BaseReasoning):
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
+    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
+    use_rewrite: bool = False

     async def run(  # type: ignore
         self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
     ) -> Document:  # type: ignore
         docs = []
         doc_ids = []
+        if self.use_rewrite:
+            rewrite = await self.rewrite_pipeline(question=message)
+            message = rewrite.text
+
         for retriever in self.retrievers:
             for doc in retriever(text=message):
                 if doc.doc_id not in doc_ids:
@@ -402,7 +449,7 @@ class FullQAPipeline(BaseReasoning):
         return answer

     @classmethod
-    def get_pipeline(cls, settings, retrievers):
+    def get_pipeline(cls, settings, states, retrievers):
         """Get the reasoning pipeline

         Args:
pipeline.answering_pipeline.qa_template = settings[ pipeline.answering_pipeline.qa_template = settings[
f"reasoning.options.{_id}.qa_prompt" f"reasoning.options.{_id}.qa_prompt"
] ]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
return pipeline return pipeline
@classmethod @classmethod
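
The net effect: a Regen click sets states["app"]["regen"], get_pipeline turns that into use_rewrite, and the UI language setting picks the rewrite language. Condensed, with illustrative values:

states = {"app": {"regen": True}}  # as left by regen_fn
settings = {"reasoning.lang": "ja"}

use_rewrite = states.get("app", {}).get("regen", False)  # True -> rewrite the question first
lang = {"en": "English", "ja": "Japanese"}.get(settings["reasoning.lang"], "English")
print(use_rewrite, lang)  # True Japanese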