feat: add file grouping feature (#416) bump:patch

This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2024-10-21 12:47:18 +07:00 committed by GitHub
parent 2bc1b01876
commit 764fe595f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 413 additions and 81 deletions

View File

@ -119,6 +119,23 @@ class FileIndex(BaseIndex):
"user": Column(Integer, default=1),
},
)
FileGroup = type(
"FileGroupTable",
(Base,),
{
"__tablename__": f"index__{self.id}__group",
"id": Column(Integer, primary_key=True, autoincrement=True),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
"name": Column(String, unique=True),
"user": Column(Integer, default=1),
"data": Column(
MutableDict.as_mutable(JSON), # type: ignore
default={"files": []},
),
},
)
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
self._docstore: BaseDocumentStore = get_docstore(f"index_{self.id}")
@ -126,6 +143,7 @@ class FileIndex(BaseIndex):
self._resources = {
"Source": Source,
"Index": Index,
"FileGroup": FileGroup,
"VectorStore": self._vs,
"DocStore": self._docstore,
"FileStoragePath": self._fs_path,
@ -297,6 +315,7 @@ class FileIndex(BaseIndex):
self._setup_resources()
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
self._resources["FileGroup"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True)
def on_delete(self):
@ -306,6 +325,7 @@ class FileIndex(BaseIndex):
self._setup_resources()
self._resources["Source"].__table__.drop(engine) # type: ignore
self._resources["Index"].__table__.drop(engine) # type: ignore
self._resources["FileGroup"].__table__.drop(engine) # type: ignore
self._vs.drop()
self._docstore.drop()
shutil.rmtree(self._fs_path)

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import json
import logging
import shutil
import threading
@ -120,6 +121,16 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
text: the text to retrieve similar documents
doc_ids: list of document ids to constraint the retrieval
"""
# flatten doc_ids in case of group of doc_ids are passed
if doc_ids:
flatten_doc_ids = []
for doc_id in doc_ids:
if doc_id.startswith("["):
flatten_doc_ids.extend(json.loads(doc_id))
else:
flatten_doc_ids.append(doc_id)
doc_ids = flatten_doc_ids
print("searching in doc_ids", doc_ids)
if not doc_ids:
logger.info(f"Skip retrieval because of no selected files: {self}")

View File

@ -1,4 +1,5 @@
import html
import json
import os
import shutil
import tempfile
@ -19,6 +20,7 @@ from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
DOWNLOAD_MESSAGE = "Press again to download"
MAX_FILENAME_LENGTH = 20
class File(gr.File):
@ -107,6 +109,117 @@ class FileIndexPage(BasePage):
return ""
def render_file_list(self):
self.filter = gr.Textbox(
value="",
label="Filter by name:",
info=(
"(1) Case-insensitive. "
"(2) Search with empty string to show all files."
),
)
self.file_list_state = gr.State(value=None)
self.file_list = gr.DataFrame(
headers=[
"id",
"name",
"size",
"tokens",
"loader",
"date_created",
],
column_widths=["0%", "50%", "8%", "7%", "15%", "20%"],
interactive=False,
wrap=False,
elem_id="file_list_view",
)
with gr.Row():
self.deselect_button = gr.Button(
"Close",
visible=False,
)
self.is_zipped_state = gr.State(value=False)
self.download_single_button = gr.DownloadButton(
"Download file",
visible=False,
)
self.delete_button = gr.Button(
"Delete",
variant="stop",
visible=False,
)
with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None)
with gr.Column(scale=2):
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.chunks = gr.HTML(visible=False)
with gr.Accordion("Advance options", open=False):
with gr.Row():
self.download_all_button = gr.DownloadButton(
"Download all files",
visible=True,
)
self.delete_all_button = gr.Button(
"Delete all files",
variant="stop",
visible=True,
)
self.delete_all_button_confirm = gr.Button(
"Confirm delete", variant="stop", visible=False
)
self.delete_all_button_cancel = gr.Button("Cancel", visible=False)
def render_group_list(self):
self.group_list_state = gr.State(value=None)
self.group_list = gr.DataFrame(
headers=[
"id",
"name",
"files",
"date_created",
],
column_widths=["0%", "25%", "55%", "20%"],
interactive=False,
wrap=False,
)
with gr.Row():
self.group_add_button = gr.Button(
"Add",
variant="primary",
)
self.group_close_button = gr.Button(
"Close",
visible=False,
)
self.group_delete_button = gr.Button(
"Delete",
variant="stop",
visible=False,
)
with gr.Column(visible=False) as self._group_info_panel:
self.group_label = gr.Markdown()
self.group_name = gr.Textbox(
label="Group name",
placeholder="Group name",
lines=1,
max_lines=1,
interactive=False,
)
self.group_files = gr.Dropdown(
label="Attached files",
multiselect=True,
)
self.group_save_button = gr.Button(
"Save",
variant="primary",
)
def on_building_ui(self):
"""Build the UI of the app"""
with gr.Row():
@ -157,76 +270,24 @@ class FileIndexPage(BasePage):
elem_classes=["right-button"],
)
gr.Markdown("## File List")
self.filter = gr.Textbox(
value="",
label="Filter by name:",
info=(
"(1) Case-insensitive. "
"(2) Search with empty string to show all files."
),
)
self.file_list_state = gr.State(value=None)
self.file_list = gr.DataFrame(
headers=[
"id",
"name",
"size",
"tokens",
"loader",
"date_created",
],
column_widths=["0%", "50%", "8%", "7%", "15%", "20%"],
interactive=False,
wrap=False,
elem_id="file_list_view",
)
with gr.Tab("Files"):
self.render_file_list()
with gr.Row():
self.deselect_button = gr.Button(
"Close",
visible=False,
)
self.delete_button = gr.Button(
"Delete",
variant="stop",
visible=False,
)
with gr.Row():
self.is_zipped_state = gr.State(value=False)
self.download_single_button = gr.DownloadButton(
"Download file",
visible=False,
)
with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None)
with gr.Column(scale=2):
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.chunks = gr.HTML(visible=False)
with gr.Accordion("Advance options", open=False):
with gr.Row():
self.download_all_button = gr.DownloadButton(
"Download all files",
visible=True,
)
self.delete_all_button = gr.Button(
"Delete all files",
variant="stop",
visible=True,
)
self.delete_all_button_confirm = gr.Button(
"Confirm delete", variant="stop", visible=False
)
self.delete_all_button_cancel = gr.Button(
"Cancel", visible=False
)
with gr.Tab("Groups"):
self.render_group_list()
def on_subscribe_public_events(self):
"""Subscribe to the declared public event of the app"""
self._app.subscribe_event(
name=f"onFileIndex{self._index.id}Changed",
definition={
"fn": self.list_file_names,
"inputs": [self.file_list_state],
"outputs": [self.group_files],
"show_progress": "hidden",
},
)
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
@ -237,6 +298,24 @@ class FileIndexPage(BasePage):
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.list_group,
"inputs": [self._app.user_id, self.file_list_state],
"outputs": [self.group_list_state, self.group_list],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.list_file_names,
"inputs": [self.file_list_state],
"outputs": [self.group_files],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
@ -525,20 +604,27 @@ class FileIndexPage(BasePage):
show_progress="hidden",
)
onUploaded = self.upload_button.click(
fn=lambda: gr.update(visible=True),
outputs=[self.upload_progress_panel],
).then(
fn=self.index_fn,
inputs=[
self.files,
self.urls,
self.reindex,
self._app.settings_state,
self._app.user_id,
],
outputs=[self.upload_result, self.upload_info],
concurrency_limit=20,
onUploaded = (
self.upload_button.click(
fn=lambda: gr.update(visible=True),
outputs=[self.upload_progress_panel],
)
.then(
fn=self.index_fn,
inputs=[
self.files,
self.urls,
self.reindex,
self._app.settings_state,
self._app.user_id,
],
outputs=[self.upload_result, self.upload_info],
concurrency_limit=20,
)
.then(
fn=lambda: gr.update(value=""),
outputs=[self.urls],
)
)
try:
@ -631,6 +717,26 @@ class FileIndexPage(BasePage):
show_progress="hidden",
)
self.group_list.select(
fn=self.interact_group_list,
inputs=[self.group_list_state],
outputs=[self.group_label, self.group_name, self.group_files],
show_progress="hidden",
).then(
fn=lambda: (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
),
outputs=[
self._group_info_panel,
self.group_add_button,
self.group_close_button,
self.group_delete_button,
],
)
self.filter.submit(
fn=self.list_file,
inputs=[self._app.user_id, self.filter],
@ -638,6 +744,58 @@ class FileIndexPage(BasePage):
show_progress="hidden",
)
self.group_add_button.click(
fn=lambda: [
gr.update(visible=False),
gr.update(value="### Add new group"),
gr.update(visible=True),
gr.update(value="", interactive=True),
gr.update(value=[]),
],
outputs=[
self.group_add_button,
self.group_label,
self._group_info_panel,
self.group_name,
self.group_files,
],
)
onGroupSaved = self.group_save_button.click(
fn=self.save_group,
inputs=[self.group_name, self.group_files, self._app.user_id],
).then(
self.list_group,
inputs=[self._app.user_id, self.file_list_state],
outputs=[self.group_list_state, self.group_list],
)
self.group_close_button.click(
fn=lambda: [
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
gr.update(visible=False),
],
outputs=[
self.group_add_button,
self._group_info_panel,
self.group_close_button,
self.group_delete_button,
],
)
onGroupDeleted = self.group_delete_button.click(
fn=self.delete_group,
inputs=[self.group_name],
).then(
self.list_group,
inputs=[self._app.user_id, self.file_list_state],
outputs=[self.group_list_state, self.group_list],
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onGroupDeleted = onGroupDeleted.then(**event)
onGroupSaved = onGroupSaved.then(**event)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
@ -926,6 +1084,125 @@ class FileIndexPage(BasePage):
return results, file_list
def list_file_names(self, file_list_state):
if file_list_state:
file_names = [(item["name"], item["id"]) for item in file_list_state]
else:
file_names = []
return gr.update(choices=file_names)
def list_group(self, user_id, file_list):
if file_list:
file_id_to_name = {item["id"]: item["name"] for item in file_list}
else:
file_id_to_name = {}
if user_id is None:
# not signed in
return [], pd.DataFrame.from_records(
[
{
"id": "-",
"name": "-",
"files": "-",
"date_created": "-",
}
]
)
FileGroup = self._index._resources["FileGroup"]
with Session(engine) as session:
statement = select(FileGroup)
if self._index.config.get("private", False):
statement = statement.where(FileGroup.user == user_id)
results = [
{
"id": each[0].id,
"name": each[0].name,
"files": each[0].data.get("files", []),
"date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"),
}
for each in session.execute(statement).all()
]
if results:
formated_results = deepcopy(results)
for item in formated_results:
file_names = [
file_id_to_name.get(file_id, "-") for file_id in item["files"]
]
item["files"] = ", ".join(
f"'{it[:MAX_FILENAME_LENGTH]}..'"
if len(it) > MAX_FILENAME_LENGTH
else f"'{it}'"
for it in file_names
)
item_count = len(file_names)
item_postfix = "s" if item_count > 1 else ""
item["files"] = f"[{item_count} item{item_postfix}] " + item["files"]
group_list = pd.DataFrame.from_records(formated_results)
else:
group_list = pd.DataFrame.from_records(
[
{
"id": "-",
"name": "-",
"files": "-",
"date_created": "-",
}
]
)
return results, group_list
def save_group(self, group_name, group_files, user_id):
FileGroup = self._index._resources["FileGroup"]
current_group = None
# check if group_name exist
with Session(engine) as session:
current_group = session.query(FileGroup).filter_by(name=group_name).first()
if not current_group:
current_group = FileGroup(
name=group_name,
data={"files": group_files}, # type: ignore
user=user_id,
)
session.add(current_group)
session.commit()
else:
# update current group with new info
current_group.name = group_name
current_group.data["files"] = group_files # Update the files
session.commit()
group_id = current_group.id
gr.Info(f"Group {group_name} has been saved")
return group_id
def delete_group(self, group_name):
FileGroup = self._index._resources["FileGroup"]
group_id = None
with Session(engine) as session:
group = session.execute(
select(FileGroup).where(FileGroup.name == group_name)
).first()
if group:
item = group[0]
group_id = item.id
session.delete(item)
session.commit()
gr.Info(f"Group {group_name} has been deleted")
else:
raise gr.Error(f"Group {group_name} not found")
return group_id
def interact_file_list(self, list_files, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No file is uploaded")
@ -938,6 +1215,18 @@ class FileIndexPage(BasePage):
name=list_files["name"][ev.index[0]]
)
def interact_group_list(self, list_groups, ev: gr.SelectData):
selected_id = ev.index[0]
if (not ev.value or ev.value == "-") and selected_id == 0:
raise gr.Error("No group is selected")
selected_item = list_groups[selected_id]
return (
"### Group Information",
gr.update(value=selected_item["name"], interactive=False),
selected_item["files"],
)
def validate(self, files: list[str]):
"""Validate if the files are valid"""
paths = [Path(file) for file in files]
@ -1044,9 +1333,9 @@ class FileSelector(BasePage):
return gr.update(value=selected_files, choices=options)
with Session(engine) as session:
# get file list from Source table
statement = select(self._index._resources["Source"])
if self._index.config.get("private", False):
statement = statement.where(
self._index._resources["Source"].user == user_id
)
@ -1056,6 +1345,18 @@ class FileSelector(BasePage):
available_ids.append(result[0].id)
options.append((result[0].name, result[0].id))
# get group list from FileGroup table
FileGroup = self._index._resources["FileGroup"]
statement = select(FileGroup)
if self._index.config.get("private", False):
statement = statement.where(FileGroup.user == user_id)
results = session.execute(statement).all()
for result in results:
item = result[0]
options.append(
(f"group: '{item.name}'", json.dumps(item.data.get("files", [])))
)
if selected_files:
available_ids_set = set(available_ids)
selected_files = [