feat: add quick file selection upon tagging on Chat input (#533) bump:patch

* fix: improve inline citation logics without rag

* fix: improve explanation for citation options

* feat: add quick file selection on Chat input
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-11-28 21:12:56 +07:00 committed by GitHub
parent f15abdbb23
commit ab6b3fc529
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 185 additions and 50 deletions

View File

@ -152,6 +152,20 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
def replace_citation_with_link(self, answer: str):
# Define the regex pattern to match 【number】
pattern = r"\d+】"
# Regular expression to match merged citations
multi_pattern = r"【([\d,\s]+)】"
# Function to replace merged citations with independent ones
def split_citations(match):
# Extract the numbers, split by comma, and create individual citations
numbers = match.group(1).split(",")
return "".join(f"{num.strip()}" for num in numbers)
# Replace merged citations in the text
answer = re.sub(multi_pattern, split_citations, answer)
# Find all citations in the answer
matches = re.finditer(pattern, answer)
matched_citations = set()
@ -240,25 +254,30 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
# try streaming first
print("Trying LLM streaming")
for out_msg in self.llm.stream(messages):
if START_ANSWER in output:
if not final_answer:
try:
left_over_answer = output.split(START_ANSWER)[1].lstrip()
except IndexError:
left_over_answer = ""
if left_over_answer:
out_msg.text = left_over_answer + out_msg.text
if evidence:
if START_ANSWER in output:
if not final_answer:
try:
left_over_answer = output.split(START_ANSWER)[
1
].lstrip()
except IndexError:
left_over_answer = ""
if left_over_answer:
out_msg.text = left_over_answer + out_msg.text
final_answer += (
out_msg.text.lstrip() if not final_answer else out_msg.text
)
final_answer += (
out_msg.text.lstrip() if not final_answer else out_msg.text
)
yield Document(channel="chat", content=out_msg.text)
# check for the edge case of citation list is repeated
# with smaller LLMs
if START_CITATION in out_msg.text:
break
else:
yield Document(channel="chat", content=out_msg.text)
# check for the edge case of citation list is repeated
# with smaller LLMs
if START_CITATION in out_msg.text:
break
output += out_msg.text
logprobs += out_msg.logprobs
except NotImplementedError:
@ -289,8 +308,10 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
# yield the final answer
final_answer = self.replace_citation_with_link(final_answer)
yield Document(channel="chat", content=None)
yield Document(channel="chat", content=final_answer)
if final_answer:
yield Document(channel="chat", content=None)
yield Document(channel="chat", content=final_answer)
return answer

View File

@ -26,6 +26,9 @@ def find_start_end_phrase(
matches = []
matched_length = 0
for sentence in [start_phrase, end_phrase]:
if sentence is None:
continue
match = SequenceMatcher(
None, sentence, context, autojunk=False
).find_longest_match()

View File

@ -177,6 +177,10 @@ class BaseApp:
"<script>"
f"{self._svg_js}"
"</script>"
"<script type='module' "
"src='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.min.js'>" # noqa
"</script>"
"<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.css'/>" # noqa
)
with gr.Blocks(

View File

@ -365,3 +365,20 @@ details.evidence {
color: #10b981;
text-decoration: none;
}
/* pop-up for file tag in chat input*/
.tribute-container ul {
background-color: var(--background-fill-primary) !important;
color: var(--body-text-color) !important;
font-family: var(--font);
font-size: var(--text-md);
}
.tribute-container li.highlight {
background-color: var(--border-color-primary) !important;
}
/* a fix for flickering background in Gradio DataFrame */
tbody:not(.row_odd) {
background: var(--table-even-background-fill);
}

View File

@ -29,6 +29,25 @@ function() {
}
"""
update_file_list_js = """
function(file_list) {
var values = [];
for (var i = 0; i < file_list.length; i++) {
values.push({
key: file_list[i][0],
value: '"' + file_list[i][0] + '"',
});
}
var tribute = new Tribute({
values: values,
noMatchTemplate: "",
allowSpaces: true,
})
input_box = document.querySelector('#chat-input textarea');
tribute.attach(input_box);
}
"""
class File(gr.File):
"""Subclass from gr.File to maintain the original filename
@ -1429,6 +1448,10 @@ class FileSelector(BasePage):
visible=False,
)
self.selector_user_id = gr.State(value=user_id)
self.selector_choices = gr.JSON(
value=[],
visible=False,
)
def on_register_events(self):
self.mode.change(
@ -1436,6 +1459,14 @@ class FileSelector(BasePage):
inputs=[self.mode, self._app.user_id],
outputs=[self.selector, self.selector_user_id],
)
# attach special event for the first index
if self._index.id == 1:
self.selector_choices.change(
fn=None,
inputs=[self.selector_choices],
js=update_file_list_js,
show_progress="hidden",
)
def as_gradio_component(self):
return [self.mode, self.selector, self.selector_user_id]
@ -1468,7 +1499,7 @@ class FileSelector(BasePage):
available_ids = []
if user_id is None:
# not signed in
return gr.update(value=selected_files, choices=options)
return gr.update(value=selected_files, choices=options), options
with Session(engine) as session:
# get file list from Source table
@ -1501,13 +1532,13 @@ class FileSelector(BasePage):
each for each in selected_files if each in available_ids_set
]
return gr.update(value=selected_files, choices=options)
return gr.update(value=selected_files, choices=options), options
def _on_app_created(self):
self._app.app.load(
self.load_files,
inputs=[self.selector, self._app.user_id],
outputs=[self.selector],
outputs=[self.selector, self.selector_choices],
)
def on_subscribe_public_events(self):
@ -1516,26 +1547,18 @@ class FileSelector(BasePage):
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"outputs": [self.selector, self.selector_choices],
"show_progress": "hidden",
},
)
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
for event_name in ["onSignIn", "onSignOut"]:
self._app.subscribe_event(
name=event_name,
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector, self.selector_choices],
"show_progress": "hidden",
},
)

View File

@ -8,7 +8,7 @@ import gradio as gr
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from ktem.index.file.ui import File
from ktem.index.file.ui import File, chat_input_focus_js
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline,
)
@ -22,7 +22,7 @@ from theflow.settings import settings as flowsettings
from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from ...utils import SUPPORTED_LANGUAGE_MAP
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex
from .chat_panel import ChatPanel
from .common import STATE
from .control import ConversationControl
@ -113,6 +113,7 @@ class ChatPage(BasePage):
self.state_plot_history = gr.State([])
self.state_plot_panel = gr.State(None)
self.state_follow_up = gr.State(None)
self.first_selector_choices = gr.State(None)
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
self.chat_control = ConversationControl(self._app)
@ -130,6 +131,11 @@ class ChatPage(BasePage):
):
index_ui.render()
gr_index = index_ui.as_gradio_component()
# get the file selector choices for the first index
if index_id == 0:
self.first_selector_choices = index_ui.selector_choices
if gr_index:
if isinstance(gr_index, list):
index.selector = tuple(
@ -272,6 +278,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self.state_follow_up,
self.first_selector_choices,
],
outputs=[
self.chat_panel.text_input,
@ -280,6 +287,9 @@ class ChatPage(BasePage):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.state_follow_up,
# file selector from the first index
self._indices_input[0],
self._indices_input[1],
],
concurrency_limit=20,
show_progress="hidden",
@ -426,6 +436,10 @@ class ChatPage(BasePage):
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None,
inputs=None,
js=chat_input_focus_js,
)
self.chat_control.btn_del.click(
@ -516,7 +530,12 @@ class ChatPage(BasePage):
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
)
# evidence display on message selection
@ -535,7 +554,12 @@ class ChatPage(BasePage):
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
)
self.chat_control.cb_is_public.change(
@ -585,7 +609,14 @@ class ChatPage(BasePage):
)
def submit_msg(
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest
self,
chat_input,
chat_history,
user_id,
conv_id,
conv_name,
chat_suggest,
first_selector_choices,
):
"""Submit a message to the chatbot"""
if not chat_input:
@ -593,6 +624,24 @@ class ChatPage(BasePage):
chat_input_text = chat_input.get("text", "")
# get all file names with pattern @"filename" in input_str
file_names, chat_input_text = get_file_names_regex(chat_input_text)
first_selector_choices_map = {
item[0]: item[1] for item in first_selector_choices
}
file_ids = []
if file_names:
for file_name in file_names:
file_id = first_selector_choices_map.get(file_name)
if file_id:
file_ids.append(file_id)
if file_ids:
selector_output = ["select", file_ids]
else:
selector_output = [gr.update(), gr.update()]
# check if regen mode is active
if chat_input_text:
chat_history = chat_history + [(chat_input_text, None)]
@ -620,14 +669,14 @@ class ChatPage(BasePage):
new_conv_name = conv_name
new_chat_suggestion = chat_suggest
return (
return [
{},
chat_history,
new_conv_id,
conv_update,
new_conv_name,
new_chat_suggestion,
)
] + selector_output
def toggle_delete(self, conv_id):
if conv_id:

View File

@ -25,7 +25,7 @@ class ChatPanel(BasePage):
interactive=True,
scale=20,
file_count="multiple",
placeholder="Chat input",
placeholder="Type a message (or tag a file with @filename)",
container=False,
show_label=False,
elem_id="chat-input",

View File

@ -410,7 +410,11 @@ class FullQAPipeline(BaseReasoning):
"name": "Citation style",
"value": "highlight",
"component": "radio",
"choices": ["highlight", "inline", "off"],
"choices": [
("highlight (long answer)", "highlight"),
("inline (precise answer)", "inline"),
("off", "off"),
],
},
"create_mindmap": {
"name": "Create Mindmap",

View File

@ -1,3 +1,4 @@
from .conversation import get_file_names_regex
from .lang import SUPPORTED_LANGUAGE_MAP
__all__ = ["SUPPORTED_LANGUAGE_MAP"]
__all__ = ["SUPPORTED_LANGUAGE_MAP", "get_file_names_regex"]

View File

@ -1,3 +1,6 @@
import re
def sync_retrieval_n_message(
messages: list[list[str]],
retrievals: list[str],
@ -16,5 +19,15 @@ def sync_retrieval_n_message(
return retrievals
def get_file_names_regex(input_str: str) -> tuple[list[str], str]:
# get all file names with pattern @"filename" in input_str
# also remove these file names from input_str
pattern = r'@"([^"]*)"'
matches = re.findall(pattern, input_str)
input_str = re.sub(pattern, "", input_str).strip()
return matches, input_str
if __name__ == "__main__":
print(sync_retrieval_n_message([[""], [""], [""]], []))