feat: add quick file selection upon tagging on Chat input (#533) bump:patch

* fix: improve inline citation logics without rag

* fix: improve explanation for citation options

* feat: add quick file selection on Chat input
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-11-28 21:12:56 +07:00 committed by GitHub
parent f15abdbb23
commit ab6b3fc529
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 185 additions and 50 deletions

View File

@ -152,6 +152,20 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
def replace_citation_with_link(self, answer: str): def replace_citation_with_link(self, answer: str):
# Define the regex pattern to match 【number】 # Define the regex pattern to match 【number】
pattern = r"【\d+】" pattern = r"【\d+】"
# Regular expression to match merged citations
multi_pattern = r"【([\d,\s]+)】"
# Function to replace merged citations with independent ones
def split_citations(match):
    """Expand a merged citation match such as 【1, 2, 3】 into a run of
    independent citations: 【1】【2】【3】.

    Args:
        match: a regex match whose group(1) is the comma-separated list of
            citation numbers captured by the merged-citation pattern.

    Returns:
        The replacement string with one bracketed citation per number.
    """
    # Extract the numbers, split by comma, and create individual citations.
    numbers = match.group(1).split(",")
    # Re-wrap each number in 【】 — emitting bare digits would delete the
    # citation markers instead of splitting them.
    return "".join(f"【{num.strip()}】" for num in numbers)
# Replace merged citations in the text
answer = re.sub(multi_pattern, split_citations, answer)
# Find all citations in the answer
matches = re.finditer(pattern, answer) matches = re.finditer(pattern, answer)
matched_citations = set() matched_citations = set()
@ -240,10 +254,13 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
# try streaming first # try streaming first
print("Trying LLM streaming") print("Trying LLM streaming")
for out_msg in self.llm.stream(messages): for out_msg in self.llm.stream(messages):
if evidence:
if START_ANSWER in output: if START_ANSWER in output:
if not final_answer: if not final_answer:
try: try:
left_over_answer = output.split(START_ANSWER)[1].lstrip() left_over_answer = output.split(START_ANSWER)[
1
].lstrip()
except IndexError: except IndexError:
left_over_answer = "" left_over_answer = ""
if left_over_answer: if left_over_answer:
@ -258,6 +275,8 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
# with smaller LLMs # with smaller LLMs
if START_CITATION in out_msg.text: if START_CITATION in out_msg.text:
break break
else:
yield Document(channel="chat", content=out_msg.text)
output += out_msg.text output += out_msg.text
logprobs += out_msg.logprobs logprobs += out_msg.logprobs
@ -289,6 +308,8 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
# yield the final answer # yield the final answer
final_answer = self.replace_citation_with_link(final_answer) final_answer = self.replace_citation_with_link(final_answer)
if final_answer:
yield Document(channel="chat", content=None) yield Document(channel="chat", content=None)
yield Document(channel="chat", content=final_answer) yield Document(channel="chat", content=final_answer)

View File

@ -26,6 +26,9 @@ def find_start_end_phrase(
matches = [] matches = []
matched_length = 0 matched_length = 0
for sentence in [start_phrase, end_phrase]: for sentence in [start_phrase, end_phrase]:
if sentence is None:
continue
match = SequenceMatcher( match = SequenceMatcher(
None, sentence, context, autojunk=False None, sentence, context, autojunk=False
).find_longest_match() ).find_longest_match()

View File

@ -177,6 +177,10 @@ class BaseApp:
"<script>" "<script>"
f"{self._svg_js}" f"{self._svg_js}"
"</script>" "</script>"
"<script type='module' "
"src='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.min.js'>" # noqa
"</script>"
"<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.css'/>" # noqa
) )
with gr.Blocks( with gr.Blocks(

View File

@ -365,3 +365,20 @@ details.evidence {
color: #10b981; color: #10b981;
text-decoration: none; text-decoration: none;
} }
/* pop-up for file tag in chat input*/
/* Theme the Tribute.js @-mention dropdown with Gradio CSS variables so it
   follows the app's active light/dark theme. */
.tribute-container ul {
background-color: var(--background-fill-primary) !important;
color: var(--body-text-color) !important;
font-family: var(--font);
font-size: var(--text-md);
}
/* Highlighted (keyboard-selected) entry in the dropdown. */
.tribute-container li.highlight {
background-color: var(--border-color-primary) !important;
}
/* a fix for flickering background in Gradio DataFrame */
tbody:not(.row_odd) {
background: var(--table-even-background-fill);
}

View File

@ -29,6 +29,25 @@ function() {
} }
""" """
# JS hook run when the hidden file-list JSON component changes: rebuilds a
# Tribute.js autocomplete from the (name, id) pairs — only item[0] (the file
# name) is used for both key and the inserted '"name"' value — and attaches
# it to the chat input textarea so typing @ pops up the file picker.
# NOTE(review): `input_box` is assigned without var/let, i.e. an implicit
# global; presumably harmless here, but confirm it is intentional.
update_file_list_js = """
function(file_list) {
var values = [];
for (var i = 0; i < file_list.length; i++) {
values.push({
key: file_list[i][0],
value: '"' + file_list[i][0] + '"',
});
}
var tribute = new Tribute({
values: values,
noMatchTemplate: "",
allowSpaces: true,
})
input_box = document.querySelector('#chat-input textarea');
tribute.attach(input_box);
}
"""
class File(gr.File): class File(gr.File):
"""Subclass from gr.File to maintain the original filename """Subclass from gr.File to maintain the original filename
@ -1429,6 +1448,10 @@ class FileSelector(BasePage):
visible=False, visible=False,
) )
self.selector_user_id = gr.State(value=user_id) self.selector_user_id = gr.State(value=user_id)
self.selector_choices = gr.JSON(
value=[],
visible=False,
)
def on_register_events(self): def on_register_events(self):
self.mode.change( self.mode.change(
@ -1436,6 +1459,14 @@ class FileSelector(BasePage):
inputs=[self.mode, self._app.user_id], inputs=[self.mode, self._app.user_id],
outputs=[self.selector, self.selector_user_id], outputs=[self.selector, self.selector_user_id],
) )
# attach special event for the first index
if self._index.id == 1:
self.selector_choices.change(
fn=None,
inputs=[self.selector_choices],
js=update_file_list_js,
show_progress="hidden",
)
def as_gradio_component(self): def as_gradio_component(self):
return [self.mode, self.selector, self.selector_user_id] return [self.mode, self.selector, self.selector_user_id]
@ -1468,7 +1499,7 @@ class FileSelector(BasePage):
available_ids = [] available_ids = []
if user_id is None: if user_id is None:
# not signed in # not signed in
return gr.update(value=selected_files, choices=options) return gr.update(value=selected_files, choices=options), options
with Session(engine) as session: with Session(engine) as session:
# get file list from Source table # get file list from Source table
@ -1501,13 +1532,13 @@ class FileSelector(BasePage):
each for each in selected_files if each in available_ids_set each for each in selected_files if each in available_ids_set
] ]
return gr.update(value=selected_files, choices=options) return gr.update(value=selected_files, choices=options), options
def _on_app_created(self): def _on_app_created(self):
self._app.app.load( self._app.app.load(
self.load_files, self.load_files,
inputs=[self.selector, self._app.user_id], inputs=[self.selector, self._app.user_id],
outputs=[self.selector], outputs=[self.selector, self.selector_choices],
) )
def on_subscribe_public_events(self): def on_subscribe_public_events(self):
@ -1516,26 +1547,18 @@ class FileSelector(BasePage):
definition={ definition={
"fn": self.load_files, "fn": self.load_files,
"inputs": [self.selector, self._app.user_id], "inputs": [self.selector, self._app.user_id],
"outputs": [self.selector], "outputs": [self.selector, self.selector_choices],
"show_progress": "hidden", "show_progress": "hidden",
}, },
) )
if self._app.f_user_management: if self._app.f_user_management:
for event_name in ["onSignIn", "onSignOut"]:
self._app.subscribe_event( self._app.subscribe_event(
name="onSignIn", name=event_name,
definition={ definition={
"fn": self.load_files, "fn": self.load_files,
"inputs": [self.selector, self._app.user_id], "inputs": [self.selector, self._app.user_id],
"outputs": [self.selector], "outputs": [self.selector, self.selector_choices],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden", "show_progress": "hidden",
}, },
) )

View File

@ -8,7 +8,7 @@ import gradio as gr
from ktem.app import BasePage from ktem.app import BasePage
from ktem.components import reasonings from ktem.components import reasonings
from ktem.db.models import Conversation, engine from ktem.db.models import Conversation, engine
from ktem.index.file.ui import File from ktem.index.file.ui import File, chat_input_focus_js
from ktem.reasoning.prompt_optimization.suggest_conversation_name import ( from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline, SuggestConvNamePipeline,
) )
@ -22,7 +22,7 @@ from theflow.settings import settings as flowsettings
from kotaemon.base import Document from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from ...utils import SUPPORTED_LANGUAGE_MAP from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex
from .chat_panel import ChatPanel from .chat_panel import ChatPanel
from .common import STATE from .common import STATE
from .control import ConversationControl from .control import ConversationControl
@ -113,6 +113,7 @@ class ChatPage(BasePage):
self.state_plot_history = gr.State([]) self.state_plot_history = gr.State([])
self.state_plot_panel = gr.State(None) self.state_plot_panel = gr.State(None)
self.state_follow_up = gr.State(None) self.state_follow_up = gr.State(None)
self.first_selector_choices = gr.State(None)
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column: with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
self.chat_control = ConversationControl(self._app) self.chat_control = ConversationControl(self._app)
@ -130,6 +131,11 @@ class ChatPage(BasePage):
): ):
index_ui.render() index_ui.render()
gr_index = index_ui.as_gradio_component() gr_index = index_ui.as_gradio_component()
# get the file selector choices for the first index
if index_id == 0:
self.first_selector_choices = index_ui.selector_choices
if gr_index: if gr_index:
if isinstance(gr_index, list): if isinstance(gr_index, list):
index.selector = tuple( index.selector = tuple(
@ -272,6 +278,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_id, self.chat_control.conversation_id,
self.chat_control.conversation_rn, self.chat_control.conversation_rn,
self.state_follow_up, self.state_follow_up,
self.first_selector_choices,
], ],
outputs=[ outputs=[
self.chat_panel.text_input, self.chat_panel.text_input,
@ -280,6 +287,9 @@ class ChatPage(BasePage):
self.chat_control.conversation, self.chat_control.conversation,
self.chat_control.conversation_rn, self.chat_control.conversation_rn,
self.state_follow_up, self.state_follow_up,
# file selector from the first index
self._indices_input[0],
self._indices_input[1],
], ],
concurrency_limit=20, concurrency_limit=20,
show_progress="hidden", show_progress="hidden",
@ -426,6 +436,10 @@ class ChatPage(BasePage):
fn=self._json_to_plot, fn=self._json_to_plot,
inputs=self.state_plot_panel, inputs=self.state_plot_panel,
outputs=self.plot_panel, outputs=self.plot_panel,
).then(
fn=None,
inputs=None,
js=chat_input_focus_js,
) )
self.chat_control.btn_del.click( self.chat_control.btn_del.click(
@ -516,7 +530,12 @@ class ChatPage(BasePage):
lambda: self.toggle_delete(""), lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
).then( ).then(
fn=None, inputs=None, outputs=None, js=pdfview_js fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
) )
# evidence display on message selection # evidence display on message selection
@ -535,7 +554,12 @@ class ChatPage(BasePage):
inputs=self.state_plot_panel, inputs=self.state_plot_panel,
outputs=self.plot_panel, outputs=self.plot_panel,
).then( ).then(
fn=None, inputs=None, outputs=None, js=pdfview_js fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
) )
self.chat_control.cb_is_public.change( self.chat_control.cb_is_public.change(
@ -585,7 +609,14 @@ class ChatPage(BasePage):
) )
def submit_msg( def submit_msg(
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest self,
chat_input,
chat_history,
user_id,
conv_id,
conv_name,
chat_suggest,
first_selector_choices,
): ):
"""Submit a message to the chatbot""" """Submit a message to the chatbot"""
if not chat_input: if not chat_input:
@ -593,6 +624,24 @@ class ChatPage(BasePage):
chat_input_text = chat_input.get("text", "") chat_input_text = chat_input.get("text", "")
# get all file names with pattern @"filename" in input_str
file_names, chat_input_text = get_file_names_regex(chat_input_text)
first_selector_choices_map = {
item[0]: item[1] for item in first_selector_choices
}
file_ids = []
if file_names:
for file_name in file_names:
file_id = first_selector_choices_map.get(file_name)
if file_id:
file_ids.append(file_id)
if file_ids:
selector_output = ["select", file_ids]
else:
selector_output = [gr.update(), gr.update()]
# check if regen mode is active # check if regen mode is active
if chat_input_text: if chat_input_text:
chat_history = chat_history + [(chat_input_text, None)] chat_history = chat_history + [(chat_input_text, None)]
@ -620,14 +669,14 @@ class ChatPage(BasePage):
new_conv_name = conv_name new_conv_name = conv_name
new_chat_suggestion = chat_suggest new_chat_suggestion = chat_suggest
return ( return [
{}, {},
chat_history, chat_history,
new_conv_id, new_conv_id,
conv_update, conv_update,
new_conv_name, new_conv_name,
new_chat_suggestion, new_chat_suggestion,
) ] + selector_output
def toggle_delete(self, conv_id): def toggle_delete(self, conv_id):
if conv_id: if conv_id:

View File

@ -25,7 +25,7 @@ class ChatPanel(BasePage):
interactive=True, interactive=True,
scale=20, scale=20,
file_count="multiple", file_count="multiple",
placeholder="Chat input", placeholder="Type a message (or tag a file with @filename)",
container=False, container=False,
show_label=False, show_label=False,
elem_id="chat-input", elem_id="chat-input",

View File

@ -410,7 +410,11 @@ class FullQAPipeline(BaseReasoning):
"name": "Citation style", "name": "Citation style",
"value": "highlight", "value": "highlight",
"component": "radio", "component": "radio",
"choices": ["highlight", "inline", "off"], "choices": [
("highlight (long answer)", "highlight"),
("inline (precise answer)", "inline"),
("off", "off"),
],
}, },
"create_mindmap": { "create_mindmap": {
"name": "Create Mindmap", "name": "Create Mindmap",

View File

@ -1,3 +1,4 @@
from .conversation import get_file_names_regex
from .lang import SUPPORTED_LANGUAGE_MAP from .lang import SUPPORTED_LANGUAGE_MAP
__all__ = ["SUPPORTED_LANGUAGE_MAP"] __all__ = ["SUPPORTED_LANGUAGE_MAP", "get_file_names_regex"]

View File

@ -1,3 +1,6 @@
import re
def sync_retrieval_n_message( def sync_retrieval_n_message(
messages: list[list[str]], messages: list[list[str]],
retrievals: list[str], retrievals: list[str],
@ -16,5 +19,15 @@ def sync_retrieval_n_message(
return retrievals return retrievals
def get_file_names_regex(input_str: str) -> tuple[list[str], str]:
    """Pull out every file tagged as @"filename" from *input_str*.

    Returns:
        A pair ``(file_names, remaining_text)`` where ``file_names`` is the
        list of quoted names in order of appearance and ``remaining_text`` is
        the input with the tags removed and outer whitespace stripped.
    """
    tag_re = re.compile(r'@"([^"]*)"')
    file_names = tag_re.findall(input_str)
    remaining_text = tag_re.sub("", input_str).strip()
    return file_names, remaining_text
if __name__ == "__main__": if __name__ == "__main__":
print(sync_retrieval_n_message([[""], [""], [""]], [])) print(sync_retrieval_n_message([[""], [""], [""]], []))