feat: add quick file selection upon tagging on Chat input (#533) bump:patch

* fix: improve inline citation logics without rag

* fix: improve explanation for citation options

* feat: add quick file selection on Chat input
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin)
2024-11-28 21:12:56 +07:00
committed by GitHub
parent f15abdbb23
commit ab6b3fc529
10 changed files with 185 additions and 50 deletions

View File

@@ -177,6 +177,10 @@ class BaseApp:
"<script>"
f"{self._svg_js}"
"</script>"
"<script type='module' "
"src='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.min.js'>" # noqa
"</script>"
"<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/tributejs/5.1.3/tribute.css'/>" # noqa
)
with gr.Blocks(

View File

@@ -365,3 +365,20 @@ details.evidence {
color: #10b981;
text-decoration: none;
}
/* pop-up for file tag in chat input*/
.tribute-container ul {
background-color: var(--background-fill-primary) !important;
color: var(--body-text-color) !important;
font-family: var(--font);
font-size: var(--text-md);
}
.tribute-container li.highlight {
background-color: var(--border-color-primary) !important;
}
/* a fix for flickering background in Gradio DataFrame */
tbody:not(.row_odd) {
background: var(--table-even-background-fill);
}

View File

@@ -29,6 +29,25 @@ function() {
}
"""
update_file_list_js = """
function(file_list) {
var values = [];
for (var i = 0; i < file_list.length; i++) {
values.push({
key: file_list[i][0],
value: '"' + file_list[i][0] + '"',
});
}
var tribute = new Tribute({
values: values,
noMatchTemplate: "",
allowSpaces: true,
})
input_box = document.querySelector('#chat-input textarea');
tribute.attach(input_box);
}
"""
class File(gr.File):
"""Subclass from gr.File to maintain the original filename
@@ -1429,6 +1448,10 @@ class FileSelector(BasePage):
visible=False,
)
self.selector_user_id = gr.State(value=user_id)
self.selector_choices = gr.JSON(
value=[],
visible=False,
)
def on_register_events(self):
self.mode.change(
@@ -1436,6 +1459,14 @@ class FileSelector(BasePage):
inputs=[self.mode, self._app.user_id],
outputs=[self.selector, self.selector_user_id],
)
# attach special event for the first index
if self._index.id == 1:
self.selector_choices.change(
fn=None,
inputs=[self.selector_choices],
js=update_file_list_js,
show_progress="hidden",
)
def as_gradio_component(self):
return [self.mode, self.selector, self.selector_user_id]
@@ -1468,7 +1499,7 @@ class FileSelector(BasePage):
available_ids = []
if user_id is None:
# not signed in
return gr.update(value=selected_files, choices=options)
return gr.update(value=selected_files, choices=options), options
with Session(engine) as session:
# get file list from Source table
@@ -1501,13 +1532,13 @@ class FileSelector(BasePage):
each for each in selected_files if each in available_ids_set
]
return gr.update(value=selected_files, choices=options)
return gr.update(value=selected_files, choices=options), options
def _on_app_created(self):
self._app.app.load(
self.load_files,
inputs=[self.selector, self._app.user_id],
outputs=[self.selector],
outputs=[self.selector, self.selector_choices],
)
def on_subscribe_public_events(self):
@@ -1516,26 +1547,18 @@ class FileSelector(BasePage):
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"outputs": [self.selector, self.selector_choices],
"show_progress": "hidden",
},
)
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector],
"show_progress": "hidden",
},
)
for event_name in ["onSignIn", "onSignOut"]:
self._app.subscribe_event(
name=event_name,
definition={
"fn": self.load_files,
"inputs": [self.selector, self._app.user_id],
"outputs": [self.selector, self.selector_choices],
"show_progress": "hidden",
},
)

View File

@@ -8,7 +8,7 @@ import gradio as gr
from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from ktem.index.file.ui import File
from ktem.index.file.ui import File, chat_input_focus_js
from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
SuggestConvNamePipeline,
)
@@ -22,7 +22,7 @@ from theflow.settings import settings as flowsettings
from kotaemon.base import Document
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from ...utils import SUPPORTED_LANGUAGE_MAP
from ...utils import SUPPORTED_LANGUAGE_MAP, get_file_names_regex
from .chat_panel import ChatPanel
from .common import STATE
from .control import ConversationControl
@@ -113,6 +113,7 @@ class ChatPage(BasePage):
self.state_plot_history = gr.State([])
self.state_plot_panel = gr.State(None)
self.state_follow_up = gr.State(None)
self.first_selector_choices = gr.State(None)
with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
self.chat_control = ConversationControl(self._app)
@@ -130,6 +131,11 @@ class ChatPage(BasePage):
):
index_ui.render()
gr_index = index_ui.as_gradio_component()
# get the file selector choices for the first index
if index_id == 0:
self.first_selector_choices = index_ui.selector_choices
if gr_index:
if isinstance(gr_index, list):
index.selector = tuple(
@@ -272,6 +278,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self.state_follow_up,
self.first_selector_choices,
],
outputs=[
self.chat_panel.text_input,
@@ -280,6 +287,9 @@ class ChatPage(BasePage):
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.state_follow_up,
# file selector from the first index
self._indices_input[0],
self._indices_input[1],
],
concurrency_limit=20,
show_progress="hidden",
@@ -426,6 +436,10 @@ class ChatPage(BasePage):
fn=self._json_to_plot,
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None,
inputs=None,
js=chat_input_focus_js,
)
self.chat_control.btn_del.click(
@@ -516,7 +530,12 @@ class ChatPage(BasePage):
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
)
# evidence display on message selection
@@ -535,7 +554,12 @@ class ChatPage(BasePage):
inputs=self.state_plot_panel,
outputs=self.plot_panel,
).then(
fn=None, inputs=None, outputs=None, js=pdfview_js
fn=lambda: True,
inputs=None,
outputs=[self._preview_links],
js=pdfview_js,
).then(
fn=None, inputs=None, outputs=None, js=chat_input_focus_js
)
self.chat_control.cb_is_public.change(
@@ -585,7 +609,14 @@ class ChatPage(BasePage):
)
def submit_msg(
self, chat_input, chat_history, user_id, conv_id, conv_name, chat_suggest
self,
chat_input,
chat_history,
user_id,
conv_id,
conv_name,
chat_suggest,
first_selector_choices,
):
"""Submit a message to the chatbot"""
if not chat_input:
@@ -593,6 +624,24 @@ class ChatPage(BasePage):
chat_input_text = chat_input.get("text", "")
# get all file names with pattern @"filename" in input_str
file_names, chat_input_text = get_file_names_regex(chat_input_text)
first_selector_choices_map = {
item[0]: item[1] for item in first_selector_choices
}
file_ids = []
if file_names:
for file_name in file_names:
file_id = first_selector_choices_map.get(file_name)
if file_id:
file_ids.append(file_id)
if file_ids:
selector_output = ["select", file_ids]
else:
selector_output = [gr.update(), gr.update()]
# check if regen mode is active
if chat_input_text:
chat_history = chat_history + [(chat_input_text, None)]
@@ -620,14 +669,14 @@ class ChatPage(BasePage):
new_conv_name = conv_name
new_chat_suggestion = chat_suggest
return (
return [
{},
chat_history,
new_conv_id,
conv_update,
new_conv_name,
new_chat_suggestion,
)
] + selector_output
def toggle_delete(self, conv_id):
if conv_id:

View File

@@ -25,7 +25,7 @@ class ChatPanel(BasePage):
interactive=True,
scale=20,
file_count="multiple",
placeholder="Chat input",
placeholder="Type a message (or tag a file with @filename)",
container=False,
show_label=False,
elem_id="chat-input",

View File

@@ -410,7 +410,11 @@ class FullQAPipeline(BaseReasoning):
"name": "Citation style",
"value": "highlight",
"component": "radio",
"choices": ["highlight", "inline", "off"],
"choices": [
("highlight (long answer)", "highlight"),
("inline (precise answer)", "inline"),
("off", "off"),
],
},
"create_mindmap": {
"name": "Create Mindmap",

View File

@@ -1,3 +1,4 @@
from .conversation import get_file_names_regex
from .lang import SUPPORTED_LANGUAGE_MAP
__all__ = ["SUPPORTED_LANGUAGE_MAP"]
__all__ = ["SUPPORTED_LANGUAGE_MAP", "get_file_names_regex"]

View File

@@ -1,3 +1,6 @@
import re
def sync_retrieval_n_message(
messages: list[list[str]],
retrievals: list[str],
@@ -16,5 +19,15 @@ def sync_retrieval_n_message(
return retrievals
def get_file_names_regex(input_str: str) -> tuple[list[str], str]:
# get all file names with pattern @"filename" in input_str
# also remove these file names from input_str
pattern = r'@"([^"]*)"'
matches = re.findall(pattern, input_str)
input_str = re.sub(pattern, "", input_str).strip()
return matches, input_str
if __name__ == "__main__":
print(sync_retrieval_n_message([[""], [""], [""]], []))