feat: add file grouping feature (#416) bump:patch
This commit is contained in:
parent
2bc1b01876
commit
764fe595f4
|
@ -119,6 +119,23 @@ class FileIndex(BaseIndex):
|
|||
"user": Column(Integer, default=1),
|
||||
},
|
||||
)
|
||||
FileGroup = type(
|
||||
"FileGroupTable",
|
||||
(Base,),
|
||||
{
|
||||
"__tablename__": f"index__{self.id}__group",
|
||||
"id": Column(Integer, primary_key=True, autoincrement=True),
|
||||
"date_created": Column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
),
|
||||
"name": Column(String, unique=True),
|
||||
"user": Column(Integer, default=1),
|
||||
"data": Column(
|
||||
MutableDict.as_mutable(JSON), # type: ignore
|
||||
default={"files": []},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
|
||||
self._docstore: BaseDocumentStore = get_docstore(f"index_{self.id}")
|
||||
|
@ -126,6 +143,7 @@ class FileIndex(BaseIndex):
|
|||
self._resources = {
|
||||
"Source": Source,
|
||||
"Index": Index,
|
||||
"FileGroup": FileGroup,
|
||||
"VectorStore": self._vs,
|
||||
"DocStore": self._docstore,
|
||||
"FileStoragePath": self._fs_path,
|
||||
|
@ -297,6 +315,7 @@ class FileIndex(BaseIndex):
|
|||
self._setup_resources()
|
||||
self._resources["Source"].metadata.create_all(engine) # type: ignore
|
||||
self._resources["Index"].metadata.create_all(engine) # type: ignore
|
||||
self._resources["FileGroup"].metadata.create_all(engine) # type: ignore
|
||||
self._fs_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def on_delete(self):
|
||||
|
@ -306,6 +325,7 @@ class FileIndex(BaseIndex):
|
|||
self._setup_resources()
|
||||
self._resources["Source"].__table__.drop(engine) # type: ignore
|
||||
self._resources["Index"].__table__.drop(engine) # type: ignore
|
||||
self._resources["FileGroup"].__table__.drop(engine) # type: ignore
|
||||
self._vs.drop()
|
||||
self._docstore.drop()
|
||||
shutil.rmtree(self._fs_path)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import threading
|
||||
|
@ -120,6 +121,16 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
|||
text: the text to retrieve similar documents
|
||||
doc_ids: list of document ids to constraint the retrieval
|
||||
"""
|
||||
# flatten doc_ids in case of group of doc_ids are passed
|
||||
if doc_ids:
|
||||
flatten_doc_ids = []
|
||||
for doc_id in doc_ids:
|
||||
if doc_id.startswith("["):
|
||||
flatten_doc_ids.extend(json.loads(doc_id))
|
||||
else:
|
||||
flatten_doc_ids.append(doc_id)
|
||||
doc_ids = flatten_doc_ids
|
||||
|
||||
print("searching in doc_ids", doc_ids)
|
||||
if not doc_ids:
|
||||
logger.info(f"Skip retrieval because of no selected files: {self}")
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import html
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
|
@ -19,6 +20,7 @@ from sqlalchemy.orm import Session
|
|||
from theflow.settings import settings as flowsettings
|
||||
|
||||
DOWNLOAD_MESSAGE = "Press again to download"
|
||||
MAX_FILENAME_LENGTH = 20
|
||||
|
||||
|
||||
class File(gr.File):
|
||||
|
@ -107,6 +109,117 @@ class FileIndexPage(BasePage):
|
|||
|
||||
return ""
|
||||
|
||||
def render_file_list(self):
|
||||
self.filter = gr.Textbox(
|
||||
value="",
|
||||
label="Filter by name:",
|
||||
info=(
|
||||
"(1) Case-insensitive. "
|
||||
"(2) Search with empty string to show all files."
|
||||
),
|
||||
)
|
||||
self.file_list_state = gr.State(value=None)
|
||||
self.file_list = gr.DataFrame(
|
||||
headers=[
|
||||
"id",
|
||||
"name",
|
||||
"size",
|
||||
"tokens",
|
||||
"loader",
|
||||
"date_created",
|
||||
],
|
||||
column_widths=["0%", "50%", "8%", "7%", "15%", "20%"],
|
||||
interactive=False,
|
||||
wrap=False,
|
||||
elem_id="file_list_view",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.deselect_button = gr.Button(
|
||||
"Close",
|
||||
visible=False,
|
||||
)
|
||||
self.is_zipped_state = gr.State(value=False)
|
||||
self.download_single_button = gr.DownloadButton(
|
||||
"Download file",
|
||||
visible=False,
|
||||
)
|
||||
self.delete_button = gr.Button(
|
||||
"Delete",
|
||||
variant="stop",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Row() as self.selection_info:
|
||||
self.selected_file_id = gr.State(value=None)
|
||||
with gr.Column(scale=2):
|
||||
self.selected_panel = gr.Markdown(self.selected_panel_false)
|
||||
|
||||
self.chunks = gr.HTML(visible=False)
|
||||
|
||||
with gr.Accordion("Advance options", open=False):
|
||||
with gr.Row():
|
||||
self.download_all_button = gr.DownloadButton(
|
||||
"Download all files",
|
||||
visible=True,
|
||||
)
|
||||
self.delete_all_button = gr.Button(
|
||||
"Delete all files",
|
||||
variant="stop",
|
||||
visible=True,
|
||||
)
|
||||
self.delete_all_button_confirm = gr.Button(
|
||||
"Confirm delete", variant="stop", visible=False
|
||||
)
|
||||
self.delete_all_button_cancel = gr.Button("Cancel", visible=False)
|
||||
|
||||
def render_group_list(self):
|
||||
self.group_list_state = gr.State(value=None)
|
||||
self.group_list = gr.DataFrame(
|
||||
headers=[
|
||||
"id",
|
||||
"name",
|
||||
"files",
|
||||
"date_created",
|
||||
],
|
||||
column_widths=["0%", "25%", "55%", "20%"],
|
||||
interactive=False,
|
||||
wrap=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.group_add_button = gr.Button(
|
||||
"Add",
|
||||
variant="primary",
|
||||
)
|
||||
self.group_close_button = gr.Button(
|
||||
"Close",
|
||||
visible=False,
|
||||
)
|
||||
self.group_delete_button = gr.Button(
|
||||
"Delete",
|
||||
variant="stop",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Column(visible=False) as self._group_info_panel:
|
||||
self.group_label = gr.Markdown()
|
||||
self.group_name = gr.Textbox(
|
||||
label="Group name",
|
||||
placeholder="Group name",
|
||||
lines=1,
|
||||
max_lines=1,
|
||||
interactive=False,
|
||||
)
|
||||
self.group_files = gr.Dropdown(
|
||||
label="Attached files",
|
||||
multiselect=True,
|
||||
)
|
||||
self.group_save_button = gr.Button(
|
||||
"Save",
|
||||
variant="primary",
|
||||
)
|
||||
|
||||
def on_building_ui(self):
|
||||
"""Build the UI of the app"""
|
||||
with gr.Row():
|
||||
|
@ -157,76 +270,24 @@ class FileIndexPage(BasePage):
|
|||
elem_classes=["right-button"],
|
||||
)
|
||||
|
||||
gr.Markdown("## File List")
|
||||
self.filter = gr.Textbox(
|
||||
value="",
|
||||
label="Filter by name:",
|
||||
info=(
|
||||
"(1) Case-insensitive. "
|
||||
"(2) Search with empty string to show all files."
|
||||
),
|
||||
)
|
||||
self.file_list_state = gr.State(value=None)
|
||||
self.file_list = gr.DataFrame(
|
||||
headers=[
|
||||
"id",
|
||||
"name",
|
||||
"size",
|
||||
"tokens",
|
||||
"loader",
|
||||
"date_created",
|
||||
],
|
||||
column_widths=["0%", "50%", "8%", "7%", "15%", "20%"],
|
||||
interactive=False,
|
||||
wrap=False,
|
||||
elem_id="file_list_view",
|
||||
)
|
||||
with gr.Tab("Files"):
|
||||
self.render_file_list()
|
||||
|
||||
with gr.Row():
|
||||
self.deselect_button = gr.Button(
|
||||
"Close",
|
||||
visible=False,
|
||||
)
|
||||
self.delete_button = gr.Button(
|
||||
"Delete",
|
||||
variant="stop",
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
self.is_zipped_state = gr.State(value=False)
|
||||
|
||||
self.download_single_button = gr.DownloadButton(
|
||||
"Download file",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
with gr.Row() as self.selection_info:
|
||||
self.selected_file_id = gr.State(value=None)
|
||||
with gr.Column(scale=2):
|
||||
self.selected_panel = gr.Markdown(self.selected_panel_false)
|
||||
|
||||
self.chunks = gr.HTML(visible=False)
|
||||
|
||||
with gr.Accordion("Advance options", open=False):
|
||||
with gr.Row():
|
||||
self.download_all_button = gr.DownloadButton(
|
||||
"Download all files",
|
||||
visible=True,
|
||||
)
|
||||
self.delete_all_button = gr.Button(
|
||||
"Delete all files",
|
||||
variant="stop",
|
||||
visible=True,
|
||||
)
|
||||
self.delete_all_button_confirm = gr.Button(
|
||||
"Confirm delete", variant="stop", visible=False
|
||||
)
|
||||
self.delete_all_button_cancel = gr.Button(
|
||||
"Cancel", visible=False
|
||||
)
|
||||
with gr.Tab("Groups"):
|
||||
self.render_group_list()
|
||||
|
||||
def on_subscribe_public_events(self):
|
||||
"""Subscribe to the declared public event of the app"""
|
||||
self._app.subscribe_event(
|
||||
name=f"onFileIndex{self._index.id}Changed",
|
||||
definition={
|
||||
"fn": self.list_file_names,
|
||||
"inputs": [self.file_list_state],
|
||||
"outputs": [self.group_files],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
||||
if self._app.f_user_management:
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
|
@ -237,6 +298,24 @@ class FileIndexPage(BasePage):
|
|||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
definition={
|
||||
"fn": self.list_group,
|
||||
"inputs": [self._app.user_id, self.file_list_state],
|
||||
"outputs": [self.group_list_state, self.group_list],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
definition={
|
||||
"fn": self.list_file_names,
|
||||
"inputs": [self.file_list_state],
|
||||
"outputs": [self.group_files],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
self._app.subscribe_event(
|
||||
name="onSignOut",
|
||||
definition={
|
||||
|
@ -525,10 +604,12 @@ class FileIndexPage(BasePage):
|
|||
show_progress="hidden",
|
||||
)
|
||||
|
||||
onUploaded = self.upload_button.click(
|
||||
onUploaded = (
|
||||
self.upload_button.click(
|
||||
fn=lambda: gr.update(visible=True),
|
||||
outputs=[self.upload_progress_panel],
|
||||
).then(
|
||||
)
|
||||
.then(
|
||||
fn=self.index_fn,
|
||||
inputs=[
|
||||
self.files,
|
||||
|
@ -540,6 +621,11 @@ class FileIndexPage(BasePage):
|
|||
outputs=[self.upload_result, self.upload_info],
|
||||
concurrency_limit=20,
|
||||
)
|
||||
.then(
|
||||
fn=lambda: gr.update(value=""),
|
||||
outputs=[self.urls],
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# quick file upload event registration of first Index only
|
||||
|
@ -631,6 +717,26 @@ class FileIndexPage(BasePage):
|
|||
show_progress="hidden",
|
||||
)
|
||||
|
||||
self.group_list.select(
|
||||
fn=self.interact_group_list,
|
||||
inputs=[self.group_list_state],
|
||||
outputs=[self.group_label, self.group_name, self.group_files],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
fn=lambda: (
|
||||
gr.update(visible=True),
|
||||
gr.update(visible=False),
|
||||
gr.update(visible=True),
|
||||
gr.update(visible=True),
|
||||
),
|
||||
outputs=[
|
||||
self._group_info_panel,
|
||||
self.group_add_button,
|
||||
self.group_close_button,
|
||||
self.group_delete_button,
|
||||
],
|
||||
)
|
||||
|
||||
self.filter.submit(
|
||||
fn=self.list_file,
|
||||
inputs=[self._app.user_id, self.filter],
|
||||
|
@ -638,6 +744,58 @@ class FileIndexPage(BasePage):
|
|||
show_progress="hidden",
|
||||
)
|
||||
|
||||
self.group_add_button.click(
|
||||
fn=lambda: [
|
||||
gr.update(visible=False),
|
||||
gr.update(value="### Add new group"),
|
||||
gr.update(visible=True),
|
||||
gr.update(value="", interactive=True),
|
||||
gr.update(value=[]),
|
||||
],
|
||||
outputs=[
|
||||
self.group_add_button,
|
||||
self.group_label,
|
||||
self._group_info_panel,
|
||||
self.group_name,
|
||||
self.group_files,
|
||||
],
|
||||
)
|
||||
|
||||
onGroupSaved = self.group_save_button.click(
|
||||
fn=self.save_group,
|
||||
inputs=[self.group_name, self.group_files, self._app.user_id],
|
||||
).then(
|
||||
self.list_group,
|
||||
inputs=[self._app.user_id, self.file_list_state],
|
||||
outputs=[self.group_list_state, self.group_list],
|
||||
)
|
||||
self.group_close_button.click(
|
||||
fn=lambda: [
|
||||
gr.update(visible=True),
|
||||
gr.update(visible=False),
|
||||
gr.update(visible=False),
|
||||
gr.update(visible=False),
|
||||
],
|
||||
outputs=[
|
||||
self.group_add_button,
|
||||
self._group_info_panel,
|
||||
self.group_close_button,
|
||||
self.group_delete_button,
|
||||
],
|
||||
)
|
||||
onGroupDeleted = self.group_delete_button.click(
|
||||
fn=self.delete_group,
|
||||
inputs=[self.group_name],
|
||||
).then(
|
||||
self.list_group,
|
||||
inputs=[self._app.user_id, self.file_list_state],
|
||||
outputs=[self.group_list_state, self.group_list],
|
||||
)
|
||||
|
||||
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
|
||||
onGroupDeleted = onGroupDeleted.then(**event)
|
||||
onGroupSaved = onGroupSaved.then(**event)
|
||||
|
||||
def _on_app_created(self):
|
||||
"""Called when the app is created"""
|
||||
self._app.app.load(
|
||||
|
@ -926,6 +1084,125 @@ class FileIndexPage(BasePage):
|
|||
|
||||
return results, file_list
|
||||
|
||||
def list_file_names(self, file_list_state):
|
||||
if file_list_state:
|
||||
file_names = [(item["name"], item["id"]) for item in file_list_state]
|
||||
else:
|
||||
file_names = []
|
||||
|
||||
return gr.update(choices=file_names)
|
||||
|
||||
def list_group(self, user_id, file_list):
|
||||
if file_list:
|
||||
file_id_to_name = {item["id"]: item["name"] for item in file_list}
|
||||
else:
|
||||
file_id_to_name = {}
|
||||
|
||||
if user_id is None:
|
||||
# not signed in
|
||||
return [], pd.DataFrame.from_records(
|
||||
[
|
||||
{
|
||||
"id": "-",
|
||||
"name": "-",
|
||||
"files": "-",
|
||||
"date_created": "-",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
FileGroup = self._index._resources["FileGroup"]
|
||||
with Session(engine) as session:
|
||||
statement = select(FileGroup)
|
||||
if self._index.config.get("private", False):
|
||||
statement = statement.where(FileGroup.user == user_id)
|
||||
|
||||
results = [
|
||||
{
|
||||
"id": each[0].id,
|
||||
"name": each[0].name,
|
||||
"files": each[0].data.get("files", []),
|
||||
"date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"),
|
||||
}
|
||||
for each in session.execute(statement).all()
|
||||
]
|
||||
|
||||
if results:
|
||||
formated_results = deepcopy(results)
|
||||
for item in formated_results:
|
||||
file_names = [
|
||||
file_id_to_name.get(file_id, "-") for file_id in item["files"]
|
||||
]
|
||||
item["files"] = ", ".join(
|
||||
f"'{it[:MAX_FILENAME_LENGTH]}..'"
|
||||
if len(it) > MAX_FILENAME_LENGTH
|
||||
else f"'{it}'"
|
||||
for it in file_names
|
||||
)
|
||||
item_count = len(file_names)
|
||||
item_postfix = "s" if item_count > 1 else ""
|
||||
item["files"] = f"[{item_count} item{item_postfix}] " + item["files"]
|
||||
|
||||
group_list = pd.DataFrame.from_records(formated_results)
|
||||
else:
|
||||
group_list = pd.DataFrame.from_records(
|
||||
[
|
||||
{
|
||||
"id": "-",
|
||||
"name": "-",
|
||||
"files": "-",
|
||||
"date_created": "-",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
return results, group_list
|
||||
|
||||
def save_group(self, group_name, group_files, user_id):
|
||||
FileGroup = self._index._resources["FileGroup"]
|
||||
current_group = None
|
||||
|
||||
# check if group_name exist
|
||||
with Session(engine) as session:
|
||||
current_group = session.query(FileGroup).filter_by(name=group_name).first()
|
||||
|
||||
if not current_group:
|
||||
current_group = FileGroup(
|
||||
name=group_name,
|
||||
data={"files": group_files}, # type: ignore
|
||||
user=user_id,
|
||||
)
|
||||
session.add(current_group)
|
||||
session.commit()
|
||||
else:
|
||||
# update current group with new info
|
||||
current_group.name = group_name
|
||||
current_group.data["files"] = group_files # Update the files
|
||||
session.commit()
|
||||
|
||||
group_id = current_group.id
|
||||
|
||||
gr.Info(f"Group {group_name} has been saved")
|
||||
return group_id
|
||||
|
||||
def delete_group(self, group_name):
|
||||
FileGroup = self._index._resources["FileGroup"]
|
||||
group_id = None
|
||||
with Session(engine) as session:
|
||||
group = session.execute(
|
||||
select(FileGroup).where(FileGroup.name == group_name)
|
||||
).first()
|
||||
if group:
|
||||
item = group[0]
|
||||
group_id = item.id
|
||||
session.delete(item)
|
||||
session.commit()
|
||||
gr.Info(f"Group {group_name} has been deleted")
|
||||
else:
|
||||
raise gr.Error(f"Group {group_name} not found")
|
||||
|
||||
return group_id
|
||||
|
||||
def interact_file_list(self, list_files, ev: gr.SelectData):
|
||||
if ev.value == "-" and ev.index[0] == 0:
|
||||
gr.Info("No file is uploaded")
|
||||
|
@ -938,6 +1215,18 @@ class FileIndexPage(BasePage):
|
|||
name=list_files["name"][ev.index[0]]
|
||||
)
|
||||
|
||||
def interact_group_list(self, list_groups, ev: gr.SelectData):
|
||||
selected_id = ev.index[0]
|
||||
if (not ev.value or ev.value == "-") and selected_id == 0:
|
||||
raise gr.Error("No group is selected")
|
||||
|
||||
selected_item = list_groups[selected_id]
|
||||
return (
|
||||
"### Group Information",
|
||||
gr.update(value=selected_item["name"], interactive=False),
|
||||
selected_item["files"],
|
||||
)
|
||||
|
||||
def validate(self, files: list[str]):
|
||||
"""Validate if the files are valid"""
|
||||
paths = [Path(file) for file in files]
|
||||
|
@ -1044,9 +1333,9 @@ class FileSelector(BasePage):
|
|||
return gr.update(value=selected_files, choices=options)
|
||||
|
||||
with Session(engine) as session:
|
||||
# get file list from Source table
|
||||
statement = select(self._index._resources["Source"])
|
||||
if self._index.config.get("private", False):
|
||||
|
||||
statement = statement.where(
|
||||
self._index._resources["Source"].user == user_id
|
||||
)
|
||||
|
@ -1056,6 +1345,18 @@ class FileSelector(BasePage):
|
|||
available_ids.append(result[0].id)
|
||||
options.append((result[0].name, result[0].id))
|
||||
|
||||
# get group list from FileGroup table
|
||||
FileGroup = self._index._resources["FileGroup"]
|
||||
statement = select(FileGroup)
|
||||
if self._index.config.get("private", False):
|
||||
statement = statement.where(FileGroup.user == user_id)
|
||||
results = session.execute(statement).all()
|
||||
for result in results:
|
||||
item = result[0]
|
||||
options.append(
|
||||
(f"group: '{item.name}'", json.dumps(item.data.get("files", [])))
|
||||
)
|
||||
|
||||
if selected_files:
|
||||
available_ids_set = set(available_ids)
|
||||
selected_files = [
|
||||
|
|
Loading…
Reference in New Issue
Block a user