diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py index 0cbd57f..896bc81 100644 --- a/libs/ktem/ktem/index/file/index.py +++ b/libs/ktem/ktem/index/file/index.py @@ -119,6 +119,23 @@ class FileIndex(BaseIndex): "user": Column(Integer, default=1), }, ) + FileGroup = type( + "FileGroupTable", + (Base,), + { + "__tablename__": f"index__{self.id}__group", + "id": Column(Integer, primary_key=True, autoincrement=True), + "date_created": Column( + DateTime(timezone=True), server_default=func.now() + ), + "name": Column(String, unique=True), + "user": Column(Integer, default=1), + "data": Column( + MutableDict.as_mutable(JSON), # type: ignore + default={"files": []}, + ), + }, + ) self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}") self._docstore: BaseDocumentStore = get_docstore(f"index_{self.id}") @@ -126,6 +143,7 @@ class FileIndex(BaseIndex): self._resources = { "Source": Source, "Index": Index, + "FileGroup": FileGroup, "VectorStore": self._vs, "DocStore": self._docstore, "FileStoragePath": self._fs_path, @@ -297,6 +315,7 @@ class FileIndex(BaseIndex): self._setup_resources() self._resources["Source"].metadata.create_all(engine) # type: ignore self._resources["Index"].metadata.create_all(engine) # type: ignore + self._resources["FileGroup"].metadata.create_all(engine) # type: ignore self._fs_path.mkdir(parents=True, exist_ok=True) def on_delete(self): @@ -306,6 +325,7 @@ class FileIndex(BaseIndex): self._setup_resources() self._resources["Source"].__table__.drop(engine) # type: ignore self._resources["Index"].__table__.drop(engine) # type: ignore + self._resources["FileGroup"].__table__.drop(engine) # type: ignore self._vs.drop() self._docstore.drop() shutil.rmtree(self._fs_path) diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index b28e243..c7cf234 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json import logging import shutil import threading @@ -120,6 +121,16 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): text: the text to retrieve similar documents doc_ids: list of document ids to constraint the retrieval """ + # flatten doc_ids in case of group of doc_ids are passed + if doc_ids: + flatten_doc_ids = [] + for doc_id in doc_ids: + if doc_id.startswith("["): + flatten_doc_ids.extend(json.loads(doc_id)) + else: + flatten_doc_ids.append(doc_id) + doc_ids = flatten_doc_ids + print("searching in doc_ids", doc_ids) if not doc_ids: logger.info(f"Skip retrieval because of no selected files: {self}") diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index 654bdf1..2779335 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -1,4 +1,5 @@ import html +import json import os import shutil import tempfile @@ -19,6 +20,7 @@ from sqlalchemy.orm import Session from theflow.settings import settings as flowsettings DOWNLOAD_MESSAGE = "Press again to download" +MAX_FILENAME_LENGTH = 20 class File(gr.File): @@ -107,6 +109,117 @@ class FileIndexPage(BasePage): return "" + def render_file_list(self): + self.filter = gr.Textbox( + value="", + label="Filter by name:", + info=( + "(1) Case-insensitive. " + "(2) Search with empty string to show all files." + ), + ) + self.file_list_state = gr.State(value=None) + self.file_list = gr.DataFrame( + headers=[ + "id", + "name", + "size", + "tokens", + "loader", + "date_created", + ], + column_widths=["0%", "50%", "8%", "7%", "15%", "20%"], + interactive=False, + wrap=False, + elem_id="file_list_view", + ) + + with gr.Row(): + self.deselect_button = gr.Button( + "Close", + visible=False, + ) + self.is_zipped_state = gr.State(value=False) + self.download_single_button = gr.DownloadButton( + "Download file", + visible=False, + ) + self.delete_button = gr.Button( + "Delete", + variant="stop", + visible=False, + ) + + with gr.Row() as self.selection_info: + self.selected_file_id = gr.State(value=None) + with gr.Column(scale=2): + self.selected_panel = gr.Markdown(self.selected_panel_false) + + self.chunks = gr.HTML(visible=False) + + with gr.Accordion("Advance options", open=False): + with gr.Row(): + self.download_all_button = gr.DownloadButton( + "Download all files", + visible=True, + ) + self.delete_all_button = gr.Button( + "Delete all files", + variant="stop", + visible=True, + ) + self.delete_all_button_confirm = gr.Button( + "Confirm delete", variant="stop", visible=False + ) + self.delete_all_button_cancel = gr.Button("Cancel", visible=False) + + def render_group_list(self): + self.group_list_state = gr.State(value=None) + self.group_list = gr.DataFrame( + headers=[ + "id", + "name", + "files", + "date_created", + ], + column_widths=["0%", "25%", "55%", "20%"], + interactive=False, + wrap=False, + ) + + with gr.Row(): + self.group_add_button = gr.Button( + "Add", + variant="primary", + ) + self.group_close_button = gr.Button( + "Close", + visible=False, + ) + self.group_delete_button = gr.Button( + "Delete", + variant="stop", + visible=False, + ) + + with gr.Column(visible=False) as self._group_info_panel: + self.group_label = gr.Markdown() + self.group_name = gr.Textbox( + label="Group name", + placeholder="Group name", + lines=1, + max_lines=1, + interactive=False, + ) + self.group_files = gr.Dropdown( + label="Attached files", + multiselect=True, + ) + self.group_save_button = gr.Button( + "Save", + variant="primary", + ) + def on_building_ui(self): """Build the UI of the app""" with gr.Row(): @@ -157,76 +270,24 @@ class FileIndexPage(BasePage): elem_classes=["right-button"], ) - gr.Markdown("## File List") - self.filter = gr.Textbox( - value="", - label="Filter by name:", - info=( - "(1) Case-insensitive. " - "(2) Search with empty string to show all files." - ), - ) - self.file_list_state = gr.State(value=None) - self.file_list = gr.DataFrame( - headers=[ - "id", - "name", - "size", - "tokens", - "loader", - "date_created", - ], - column_widths=["0%", "50%", "8%", "7%", "15%", "20%"], - interactive=False, - wrap=False, - elem_id="file_list_view", - ) + with gr.Tab("Files"): + self.render_file_list() - with gr.Row(): - self.deselect_button = gr.Button( - "Close", - visible=False, - ) - self.delete_button = gr.Button( - "Delete", - variant="stop", - visible=False, - ) - with gr.Row(): - self.is_zipped_state = gr.State(value=False) - - self.download_single_button = gr.DownloadButton( - "Download file", - visible=False, - ) - - with gr.Row() as self.selection_info: - self.selected_file_id = gr.State(value=None) - with gr.Column(scale=2): - self.selected_panel = gr.Markdown(self.selected_panel_false) - - self.chunks = gr.HTML(visible=False) - - with gr.Accordion("Advance options", open=False): - with gr.Row(): - self.download_all_button = gr.DownloadButton( - "Download all files", - visible=True, - ) - self.delete_all_button = gr.Button( - "Delete all files", - variant="stop", - visible=True, - ) - self.delete_all_button_confirm = gr.Button( - "Confirm delete", variant="stop", visible=False - ) - self.delete_all_button_cancel = gr.Button( - "Cancel", visible=False - ) + with gr.Tab("Groups"): + self.render_group_list() def on_subscribe_public_events(self): """Subscribe to the declared public event of the app""" + self._app.subscribe_event( + name=f"onFileIndex{self._index.id}Changed", + definition={ + "fn": self.list_file_names, + "inputs": [self.file_list_state], + "outputs": [self.group_files], + "show_progress": "hidden", + }, + ) + if self._app.f_user_management: self._app.subscribe_event( name="onSignIn", @@ -237,6 +298,24 @@ class FileIndexPage(BasePage): "show_progress": "hidden", }, ) + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.list_group, + "inputs": [self._app.user_id, self.file_list_state], + "outputs": [self.group_list_state, self.group_list], + "show_progress": "hidden", + }, + ) + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.list_file_names, + "inputs": [self.file_list_state], + "outputs": [self.group_files], + "show_progress": "hidden", + }, + ) self._app.subscribe_event( name="onSignOut", definition={ @@ -525,20 +604,27 @@ class FileIndexPage(BasePage): show_progress="hidden", ) - onUploaded = self.upload_button.click( - fn=lambda: gr.update(visible=True), - outputs=[self.upload_progress_panel], - ).then( - fn=self.index_fn, - inputs=[ - self.files, - self.urls, - self.reindex, - self._app.settings_state, - self._app.user_id, - ], - outputs=[self.upload_result, self.upload_info], - concurrency_limit=20, + onUploaded = ( + self.upload_button.click( + fn=lambda: gr.update(visible=True), + outputs=[self.upload_progress_panel], + ) + .then( + fn=self.index_fn, + inputs=[ + self.files, + self.urls, + self.reindex, + self._app.settings_state, + self._app.user_id, + ], + outputs=[self.upload_result, self.upload_info], + concurrency_limit=20, + ) + .then( + fn=lambda: gr.update(value=""), + outputs=[self.urls], + ) ) try: @@ -631,6 +717,26 @@ class FileIndexPage(BasePage): show_progress="hidden", ) + self.group_list.select( + fn=self.interact_group_list, + inputs=[self.group_list_state], + outputs=[self.group_label, self.group_name, self.group_files], + show_progress="hidden", + ).then( + fn=lambda: ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=True), + gr.update(visible=True), + ), + outputs=[ + self._group_info_panel, + self.group_add_button, + self.group_close_button, + self.group_delete_button, + ], + ) + self.filter.submit( fn=self.list_file, inputs=[self._app.user_id, self.filter], @@ -638,6 +744,58 @@ class FileIndexPage(BasePage): show_progress="hidden", ) + self.group_add_button.click( + fn=lambda: [ + gr.update(visible=False), + gr.update(value="### Add new group"), + gr.update(visible=True), + gr.update(value="", interactive=True), + gr.update(value=[]), + ], + outputs=[ + self.group_add_button, + self.group_label, + self._group_info_panel, + self.group_name, + self.group_files, + ], + ) + + onGroupSaved = self.group_save_button.click( + fn=self.save_group, + inputs=[self.group_name, self.group_files, self._app.user_id], + ).then( + self.list_group, + inputs=[self._app.user_id, self.file_list_state], + outputs=[self.group_list_state, self.group_list], + ) + self.group_close_button.click( + fn=lambda: [ + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False), + gr.update(visible=False), + ], + outputs=[ + self.group_add_button, + self._group_info_panel, + self.group_close_button, + self.group_delete_button, + ], + ) + onGroupDeleted = self.group_delete_button.click( + fn=self.delete_group, + inputs=[self.group_name], + ).then( + self.list_group, + inputs=[self._app.user_id, self.file_list_state], + outputs=[self.group_list_state, self.group_list], + ) + + for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): + onGroupDeleted = onGroupDeleted.then(**event) + onGroupSaved = onGroupSaved.then(**event) + def _on_app_created(self): """Called when the app is created""" self._app.app.load( @@ -926,6 +1084,125 @@ class FileIndexPage(BasePage): return results, file_list + def list_file_names(self, file_list_state): + if file_list_state: + file_names = [(item["name"], item["id"]) for item in file_list_state] + else: + file_names = [] + + return gr.update(choices=file_names) + + def list_group(self, user_id, file_list): + if file_list: + file_id_to_name = {item["id"]: item["name"] for item in file_list} + else: + file_id_to_name = {} + + if user_id is None: + # not signed in + return [], pd.DataFrame.from_records( + [ + { + "id": "-", + "name": "-", + "files": "-", + "date_created": "-", + } + ] + ) + + FileGroup = self._index._resources["FileGroup"] + with Session(engine) as session: + statement = select(FileGroup) + if self._index.config.get("private", False): + statement = statement.where(FileGroup.user == user_id) + + results = [ + { + "id": each[0].id, + "name": each[0].name, + "files": each[0].data.get("files", []), + "date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"), + } + for each in session.execute(statement).all() + ] + + if results: + formated_results = deepcopy(results) + for item in formated_results: + file_names = [ + file_id_to_name.get(file_id, "-") for file_id in item["files"] + ] + item["files"] = ", ".join( + f"'{it[:MAX_FILENAME_LENGTH]}..'" + if len(it) > MAX_FILENAME_LENGTH + else f"'{it}'" + for it in file_names + ) + item_count = len(file_names) + item_postfix = "s" if item_count > 1 else "" + item["files"] = f"[{item_count} item{item_postfix}] " + item["files"] + + group_list = pd.DataFrame.from_records(formated_results) + else: + group_list = pd.DataFrame.from_records( + [ + { + "id": "-", + "name": "-", + "files": "-", + "date_created": "-", + } + ] + ) + + return results, group_list + + def save_group(self, group_name, group_files, user_id): + FileGroup = self._index._resources["FileGroup"] + current_group = None + + # check if group_name exist + with Session(engine) as session: + current_group = session.query(FileGroup).filter_by(name=group_name).first() + + if not current_group: + current_group = FileGroup( + name=group_name, + data={"files": group_files}, # type: ignore + user=user_id, + ) + session.add(current_group) + session.commit() + else: + # update current group with new info + current_group.name = group_name + current_group.data["files"] = group_files # Update the files + session.commit() + + group_id = current_group.id + + gr.Info(f"Group {group_name} has been saved") + return group_id + + def delete_group(self, group_name): + FileGroup = self._index._resources["FileGroup"] + group_id = None + with Session(engine) as session: + group = session.execute( + select(FileGroup).where(FileGroup.name == group_name) + ).first() + if group: + item = group[0] + group_id = item.id + session.delete(item) + session.commit() + gr.Info(f"Group {group_name} has been deleted") + else: + raise gr.Error(f"Group {group_name} not found") + + return group_id + def interact_file_list(self, list_files, ev: gr.SelectData): if ev.value == "-" and ev.index[0] == 0: gr.Info("No file is uploaded") @@ -938,6 +1215,18 @@ class FileIndexPage(BasePage): name=list_files["name"][ev.index[0]] ) + def interact_group_list(self, list_groups, ev: gr.SelectData): + selected_id = ev.index[0] + if (not ev.value or ev.value == "-") and selected_id == 0: + raise gr.Error("No group is selected") + + selected_item = list_groups[selected_id] + return ( + "### Group Information", + gr.update(value=selected_item["name"], interactive=False), + selected_item["files"], + ) + def validate(self, files: list[str]): """Validate if the files are valid""" paths = [Path(file) for file in files] @@ -1044,9 +1333,9 @@ class FileSelector(BasePage): return gr.update(value=selected_files, choices=options) with Session(engine) as session: + # get file list from Source table statement = select(self._index._resources["Source"]) if self._index.config.get("private", False): - statement = statement.where( self._index._resources["Source"].user == user_id ) @@ -1056,6 +1345,18 @@ class FileSelector(BasePage): available_ids.append(result[0].id) options.append((result[0].name, result[0].id)) + # get group list from FileGroup table + FileGroup = self._index._resources["FileGroup"] + statement = select(FileGroup) + if self._index.config.get("private", False): + statement = statement.where(FileGroup.user == user_id) + results = session.execute(statement).all() + for result in results: + item = result[0] + options.append( + (f"group: '{item.name}'", json.dumps(item.data.get("files", []))) + ) + if selected_files: available_ids_set = set(available_ids) selected_files = [