From ecf09b275fc6bd6a4348f63906be7c7d7e99595b Mon Sep 17 00:00:00 2001 From: ian_Cin Date: Wed, 3 Apr 2024 16:33:54 +0700 Subject: [PATCH] Fix UI bugs (#8) * Auto create conversation when the user starts * Add conversation rename rule check * Fix empty name during save * Confirm deleting conversation * Show warning if users don't select file when upload files in the File Index * Feedback when user uploads duplicated file * Limit the file types * Fix valid username * Allow login when username with leading and trailing whitespaces * Improve the user * Disable admin panel for non-admnin user * Refresh user lists after creating/deleting users * Auto logging in * Clear admin information upon signing out * Fix unable to receive uploaded filename that include special characters, like !@#$%^&*().pdf * Set upload validation for FileIndex * Improve user management UI/UIX * Show extraction error when indexing file * Return selected user -1 when signing out * Fix default supported file types in file index * Validate changing password * Allow the selector to contain mulitple gradio components * A more tolerable placeholder screen * Allow chat suggestion box * Increase concurrency limit * Make adobe loader optional * Use BaseReasoning --------- Co-authored-by: trducng --- .gitignore | 1 + .pre-commit-config.yaml | 7 +- libs/kotaemon/kotaemon/indices/qa/citation.py | 14 +- libs/kotaemon/kotaemon/loaders/__init__.py | 3 +- .../kotaemon/kotaemon/loaders/adobe_loader.py | 15 +- libs/kotaemon/kotaemon/loaders/ocr_loader.py | 67 ++++++ libs/ktem/ktem/app.py | 4 +- libs/ktem/ktem/index/base.py | 6 +- libs/ktem/ktem/index/file/base.py | 8 + libs/ktem/ktem/index/file/index.py | 140 +++++++++++-- libs/ktem/ktem/index/file/pipelines.py | 63 +++++- libs/ktem/ktem/index/file/ui.py | 140 ++++++++++--- libs/ktem/ktem/index/manager.py | 32 ++- libs/ktem/ktem/main.py | 38 +++- libs/ktem/ktem/pages/admin/user.py | 187 +++++++++++------ libs/ktem/ktem/pages/chat/__init__.py | 198 ++++++++++++++++-- libs/ktem/ktem/pages/chat/chat_suggestion.py | 26 +++ libs/ktem/ktem/pages/chat/control.py | 97 +++++---- libs/ktem/ktem/pages/chat/report.py | 13 +- libs/ktem/ktem/pages/login.py | 51 ++++- libs/ktem/ktem/pages/settings.py | 17 +- libs/ktem/ktem/reasoning/base.py | 6 +- libs/ktem/ktem/reasoning/simple.py | 58 +++-- 23 files changed, 936 insertions(+), 255 deletions(-) create mode 100644 libs/ktem/ktem/pages/chat/chat_suggestion.py diff --git a/.gitignore b/.gitignore index 0114278..5c91c3e 100644 --- a/.gitignore +++ b/.gitignore @@ -452,6 +452,7 @@ $RECYCLE.BIN/ .theflow/ # End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm +*.py[coid] logs/ .gitsecret/keys/random_seed diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 21356ce..3f68b56 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,12 @@ repos: hooks: - id: mypy additional_dependencies: - [types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"] + [ + types-PyYAML==6.0.12.11, + "types-requests", + "sqlmodel", + "types-Markdown", + ] args: ["--check-untyped-defs", "--ignore-missing-imports"] exclude: "^templates/" - repo: https://github.com/codespell-project/codespell diff --git a/libs/kotaemon/kotaemon/indices/qa/citation.py b/libs/kotaemon/kotaemon/indices/qa/citation.py index f1a53c7..3192a07 100644 --- a/libs/kotaemon/kotaemon/indices/qa/citation.py +++ b/libs/kotaemon/kotaemon/indices/qa/citation.py @@ -104,18 +104,16 @@ class CitationPipeline(BaseComponent): print("CitationPipeline: invoking LLM") llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs) print("CitationPipeline: finish invoking LLM") + if not llm_output.messages: + return None + function_output = llm_output.messages[0].additional_kwargs["function_call"][ + "arguments" + ] + output = QuestionAnswer.parse_raw(function_output) except Exception as e: print(e) return None - if not llm_output.messages: - return None - - function_output = llm_output.messages[0].additional_kwargs["function_call"][ - "arguments" - ] - output = QuestionAnswer.parse_raw(function_output) - return output async def ainvoke(self, context: str, question: str): diff --git a/libs/kotaemon/kotaemon/loaders/__init__.py b/libs/kotaemon/kotaemon/loaders/__init__.py index 28cb5f3..a59d713 100644 --- a/libs/kotaemon/kotaemon/loaders/__init__.py +++ b/libs/kotaemon/kotaemon/loaders/__init__.py @@ -5,7 +5,7 @@ from .docx_loader import DocxReader from .excel_loader import PandasExcelReader from .html_loader import HtmlReader from .mathpix_loader import MathpixPDFReader -from .ocr_loader import OCRReader +from .ocr_loader import ImageReader, OCRReader from .unstructured_loader import UnstructuredReader __all__ = [ @@ -13,6 +13,7 @@ __all__ = [ "BaseReader", "PandasExcelReader", "MathpixPDFReader", + "ImageReader", "OCRReader", "DirectoryReader", "UnstructuredReader", diff --git a/libs/kotaemon/kotaemon/loaders/adobe_loader.py b/libs/kotaemon/kotaemon/loaders/adobe_loader.py index dd8cbc9..09a802c 100644 --- a/libs/kotaemon/kotaemon/loaders/adobe_loader.py +++ b/libs/kotaemon/kotaemon/loaders/adobe_loader.py @@ -10,14 +10,6 @@ from llama_index.readers.base import BaseReader from kotaemon.base import Document -from .utils.adobe import ( - generate_figure_captions, - load_json, - parse_figure_paths, - parse_table_paths, - request_adobe_service, -) - logger = logging.getLogger(__name__) DEFAULT_VLM_ENDPOINT = ( @@ -74,6 +66,13 @@ class AdobeReader(BaseReader): includes 3 types: text, table, and image """ + from .utils.adobe import ( + generate_figure_captions, + load_json, + parse_figure_paths, + parse_table_paths, + request_adobe_service, + ) filename = file.name filepath = str(Path(file).resolve()) diff --git a/libs/kotaemon/kotaemon/loaders/ocr_loader.py b/libs/kotaemon/kotaemon/loaders/ocr_loader.py index e689717..bb1ac5d 100644 --- a/libs/kotaemon/kotaemon/loaders/ocr_loader.py +++ b/libs/kotaemon/kotaemon/loaders/ocr_loader.py @@ -125,3 +125,70 @@ class OCRReader(BaseReader): ) return documents + + +class ImageReader(BaseReader): + """Read PDF using OCR, with high focus on table extraction + + Example: + ```python + >> from knowledgehub.loaders import OCRReader + >> reader = OCRReader() + >> documents = reader.load_data("path/to/pdf") + ``` + + Args: + endpoint: URL to FullOCR endpoint. If not provided, will look for + environment variable `OCR_READER_ENDPOINT` or use the default + `knowledgehub.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT` + (http://127.0.0.1:8000/v2/ai/infer/) + use_ocr: whether to use OCR to read text (e.g: from images, tables) in the PDF + If False, only the table and text within table cells will be extracted. + """ + + def __init__(self, endpoint: Optional[str] = None): + """Init the OCR reader with OCR endpoint (FullOCR pipeline)""" + super().__init__() + self.ocr_endpoint = endpoint or os.getenv( + "OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT + ) + + def load_data( + self, file_path: Path, extra_info: Optional[dict] = None, **kwargs + ) -> List[Document]: + """Load data using OCR reader + + Args: + file_path (Path): Path to PDF file + debug_path (Path): Path to store debug image output + artifact_path (Path): Path to OCR endpoints artifacts directory + + Returns: + List[Document]: list of documents extracted from the PDF file + """ + file_path = Path(file_path).resolve() + + with file_path.open("rb") as content: + files = {"input": content} + data = {"job_id": uuid4(), "table_only": False} + + # call the API from FullOCR endpoint + if "response_content" in kwargs: + # overriding response content if specified + ocr_results = kwargs["response_content"] + else: + # call original API + resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data) + ocr_results = resp.json()["result"] + + extra_info = extra_info or {} + result = [] + for ocr_result in ocr_results: + result.append( + Document( + content=ocr_result["csv_string"], + metadata=extra_info, + ) + ) + + return result diff --git a/libs/ktem/ktem/app.py b/libs/ktem/ktem/app.py index 9bac904..64e8a9d 100644 --- a/libs/ktem/ktem/app.py +++ b/libs/ktem/ktem/app.py @@ -229,7 +229,9 @@ class BasePage: def _on_app_created(self): """Called when the app is created""" - def as_gradio_component(self) -> Optional[gr.components.Component]: + def as_gradio_component( + self, + ) -> Optional[gr.components.Component | list[gr.components.Component]]: """Return the gradio components responsible for events Note: in ideal scenario, this method shouldn't be necessary. diff --git a/libs/ktem/ktem/index/base.py b/libs/ktem/ktem/index/base.py index 50bdd9e..5183762 100644 --- a/libs/ktem/ktem/index/base.py +++ b/libs/ktem/ktem/index/base.py @@ -1,6 +1,6 @@ import abc import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional if TYPE_CHECKING: from ktem.app import BasePage @@ -57,7 +57,7 @@ class BaseIndex(abc.ABC): self._app = app self.id = id self.name = name - self._config = config # admin settings + self.config = config # admin settings def on_create(self): """Create the index for the first time""" @@ -121,7 +121,7 @@ class BaseIndex(abc.ABC): ... def get_retriever_pipelines( - self, settings: dict, selected: Optional[list] + self, settings: dict, selected: Any = None ) -> list["BaseComponent"]: """Return the retriever pipelines to retrieve the entity from the index""" return [] diff --git a/libs/ktem/ktem/index/file/base.py b/libs/ktem/ktem/index/file/base.py index 5f8e6f4..4f28f51 100644 --- a/libs/ktem/ktem/index/file/base.py +++ b/libs/ktem/ktem/index/file/base.py @@ -127,3 +127,11 @@ class BaseFileIndexIndexing(BaseComponent): the absolute file storage path to the file """ raise NotImplementedError + + def warning(self, msg): + """Log a warning message + + Args: + msg: the message to log + """ + print(msg) diff --git a/libs/ktem/ktem/index/file/index.py b/libs/ktem/ktem/index/file/index.py index ab1f35a..5fe3955 100644 --- a/libs/ktem/ktem/index/file/index.py +++ b/libs/ktem/ktem/index/file/index.py @@ -13,7 +13,6 @@ from theflow.utils.modules import import_dotted_string from kotaemon.storages import BaseDocumentStore, BaseVectorStore from .base import BaseFileIndexIndexing, BaseFileIndexRetriever -from .ui import FileIndexPage, FileSelector class FileIndex(BaseIndex): @@ -77,9 +76,15 @@ class FileIndex(BaseIndex): self._indexing_pipeline_cls: Type[BaseFileIndexIndexing] self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]] + self._selector_ui_cls: Type + self._selector_ui: Any = None + self._index_ui_cls: Type + self._index_ui: Any = None self._setup_indexing_cls() self._setup_retriever_cls() + self._setup_file_index_ui_cls() + self._setup_file_selector_ui_cls() self._default_settings: dict[str, dict] = {} self._setting_mappings: dict[str, dict] = {} @@ -91,14 +96,14 @@ class FileIndex(BaseIndex): The indexing class will is retrieved from the following order. Stop at the first order found: - - `FILE_INDEX_PIPELINE` in self._config + - `FILE_INDEX_PIPELINE` in self.config - `FILE_INDEX_{id}_PIPELINE` in the flowsettings - `FILE_INDEX_PIPELINE` in the flowsettings - The default .pipelines.IndexDocumentPipeline """ - if "FILE_INDEX_PIPELINE" in self._config: + if "FILE_INDEX_PIPELINE" in self.config: self._indexing_pipeline_cls = import_dotted_string( - self._config["FILE_INDEX_PIPELINE"], safe=False + self.config["FILE_INDEX_PIPELINE"], safe=False ) return @@ -125,15 +130,15 @@ class FileIndex(BaseIndex): The retriever classes will is retrieved from the following order. Stop at the first order found: - - `FILE_INDEX_RETRIEVER_PIPELINES` in self._config + - `FILE_INDEX_RETRIEVER_PIPELINES` in self.config - `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings - `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings - The default .pipelines.DocumentRetrievalPipeline """ - if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config: + if "FILE_INDEX_RETRIEVER_PIPELINES" in self.config: self._retriever_pipeline_cls = [ import_dotted_string(each, safe=False) - for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"] + for each in self.config["FILE_INDEX_RETRIEVER_PIPELINES"] ] return @@ -157,6 +162,76 @@ class FileIndex(BaseIndex): self._retriever_pipeline_cls = [DocumentRetrievalPipeline] + def _setup_file_selector_ui_cls(self): + """Retrieve the file selector UI for the file index + + There can be multiple retriever classes. + + The retriever classes will is retrieved from the following order. Stop at the + first order found: + - `FILE_INDEX_SELECTOR_UI` in self.config + - `FILE_INDEX_{id}_SELECTOR_UI` in the flowsettings + - `FILE_INDEX_SELECTOR_UI` in the flowsettings + - The default .ui.FileSelector + """ + if "FILE_INDEX_SELECTOR_UI" in self.config: + self._selector_ui_cls = import_dotted_string( + self.config["FILE_INDEX_SELECTOR_UI"], safe=False + ) + return + + if hasattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"): + self._selector_ui_cls = import_dotted_string( + getattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"), + safe=False, + ) + return + + if hasattr(flowsettings, "FILE_INDEX_SELECTOR_UI"): + self._selector_ui_cls = import_dotted_string( + getattr(flowsettings, "FILE_INDEX_SELECTOR_UI"), safe=False + ) + return + + from .ui import FileSelector + + self._selector_ui_cls = FileSelector + + def _setup_file_index_ui_cls(self): + """Retrieve the Index UI class + + There can be multiple retriever classes. + + The retriever classes will is retrieved from the following order. Stop at the + first order found: + - `FILE_INDEX_UI` in self.config + - `FILE_INDEX_{id}_UI` in the flowsettings + - `FILE_INDEX_UI` in the flowsettings + - The default .ui.FileIndexPage + """ + if "FILE_INDEX_UI" in self.config: + self._index_ui_cls = import_dotted_string( + self.config["FILE_INDEX_UI"], safe=False + ) + return + + if hasattr(flowsettings, f"FILE_INDEX_{self.id}_UI"): + self._index_ui_cls = import_dotted_string( + getattr(flowsettings, f"FILE_INDEX_{self.id}_UI"), + safe=False, + ) + return + + if hasattr(flowsettings, "FILE_INDEX_UI"): + self._index_ui_cls = import_dotted_string( + getattr(flowsettings, "FILE_INDEX_UI"), safe=False + ) + return + + from .ui import FileIndexPage + + self._index_ui_cls = FileIndexPage + def on_create(self): """Create the index for the first time @@ -165,6 +240,13 @@ class FileIndex(BaseIndex): 2. Create the vectorstore 3. Create the docstore """ + file_types_str = self.config.get( + "supported_file_types", + self.get_admin_settings()["supported_file_types"]["value"], + ) + file_types = [each.strip() for each in file_types_str.split(",")] + self.config["supported_file_types"] = file_types + self._resources["Source"].metadata.create_all(engine) # type: ignore self._resources["Index"].metadata.create_all(engine) # type: ignore self._fs_path.mkdir(parents=True, exist_ok=True) @@ -180,10 +262,14 @@ class FileIndex(BaseIndex): shutil.rmtree(self._fs_path) def get_selector_component_ui(self): - return FileSelector(self._app, self) + if self._selector_ui is None: + self._selector_ui = self._selector_ui_cls(self._app, self) + return self._selector_ui def get_index_page_ui(self): - return FileIndexPage(self._app, self) + if self._index_ui is None: + self._index_ui = self._index_ui_cls(self._app, self) + return self._index_ui def get_user_settings(self): if self._default_settings: @@ -210,7 +296,31 @@ class FileIndex(BaseIndex): "value": embedding_default, "component": "dropdown", "choices": embedding_choices, - } + }, + "supported_file_types": { + "name": "Supported file types", + "value": ( + "image, .pdf, .txt, .csv, .xlsx, .doc, .docx, .pptx, .html, .zip" + ), + "component": "text", + }, + "max_file_size": { + "name": "Max file size (MB) - set 0 to disable", + "value": 1000, + "component": "number", + }, + "max_number_of_files": { + "name": "Max number of files that can be indexed - set 0 to disable", + "value": 0, + "component": "number", + }, + "max_number_of_text_length": { + "name": ( + "Max amount of characters that can be indexed - set 0 to disable" + ), + "value": 0, + "component": "number", + }, } def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing: @@ -224,14 +334,15 @@ class FileIndex(BaseIndex): else: stripped_settings[key] = value - obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config) + obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config) obj.set_resources(resources=self._resources) return obj def get_retriever_pipelines( - self, settings: dict, selected: Optional[list] = None + self, settings: dict, selected: Any = None ) -> list["BaseFileIndexRetriever"]: + # retrieval settings prefix = f"index.options.{self.id}." stripped_settings = {} for key, value in settings.items(): @@ -240,9 +351,12 @@ class FileIndex(BaseIndex): else: stripped_settings[key] = value + # transform selected id + selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected) + retrievers = [] for cls in self._retriever_pipeline_cls: - obj = cls.get_pipeline(stripped_settings, self._config, selected) + obj = cls.get_pipeline(stripped_settings, self.config, selected_ids) if obj is None: continue obj.set_resources(self._resources) diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index 68b3a4d..b63d89c 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -9,6 +9,7 @@ from hashlib import sha256 from pathlib import Path from typing import Optional +import gradio as gr from ktem.components import embeddings, filestorage_path from ktem.db.models import engine from llama_index.vector_stores import ( @@ -18,7 +19,7 @@ from llama_index.vector_stores import ( MetadataFilters, ) from llama_index.vector_stores.types import VectorStoreQueryMode -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.orm import Session from theflow.settings import settings from theflow.utils.modules import import_dotted_string @@ -279,6 +280,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): to_index: list[str] = [] file_to_hash: dict[str, str] = {} errors = [] + to_update = [] for file_path in file_paths: abs_path = str(Path(file_path).resolve()) @@ -291,16 +293,26 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): statement = select(Source).where(Source.name == Path(abs_path).name) item = session.execute(statement).first() - if item and not reindex: - errors.append(Path(abs_path).name) - continue + if item: + if not reindex: + errors.append(Path(abs_path).name) + continue + else: + to_update.append(Path(abs_path).name) to_index.append(abs_path) if errors: + error_files = ", ".join(errors) + if len(error_files) > 100: + error_files = error_files[:80] + "..." print( - "Files already exist. Please rename/remove them or enable reindex.\n" - f"{errors}" + "Skip these files already exist. Please rename/remove them or " + f"enable reindex:\n{errors}" + ) + self.warning( + "Skip these files already exist. Please rename/remove them or " + f"enable reindex:\n{error_files}" ) if not to_index: @@ -310,9 +322,19 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): for path in to_index: shutil.copy(path, filestorage_path / file_to_hash[path]) - # prepare record info + # extract the file & prepare record info file_to_source: dict = {} + extraction_errors = [] + nodes = [] for file_path, file_hash in file_to_hash.items(): + if str(Path(file_path).resolve()) not in to_index: + continue + + extraction_result = self.file_ingestor(file_path) + if not extraction_result: + extraction_errors.append(Path(file_path).name) + continue + nodes.extend(extraction_result) source = Source( name=Path(file_path).name, path=file_hash, @@ -320,9 +342,23 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): ) file_to_source[file_path] = source - # extract the files - nodes = self.file_ingestor(to_index) - print("Extracted", len(to_index), "files into", len(nodes), "nodes") + if extraction_errors: + msg = "Failed to extract these files: {}".format( + ", ".join(extraction_errors) + ) + print(msg) + self.warning(msg) + + if not nodes: + return [], [] + + print( + "Extracted", + len(to_index) - len(extraction_errors), + "files into", + len(nodes), + "nodes", + ) # index the files print("Indexing the files into vector store") @@ -332,7 +368,11 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): # persist to the index print("Persisting the vector and the document into index") file_ids = [] + to_update = list(set(to_update)) with Session(engine) as session: + if to_update: + session.execute(delete(Source).where(Source.name.in_(to_update))) + for source in file_to_source.values(): session.add(source) session.commit() @@ -404,3 +444,6 @@ class IndexDocumentPipeline(BaseFileIndexIndexing): super().set_resources(resources) self.indexing_vector_pipeline.vector_store = self._VS self.indexing_vector_pipeline.doc_store = self._DS + + def warning(self, msg): + gr.Warning(msg) diff --git a/libs/ktem/ktem/index/file/ui.py b/libs/ktem/ktem/index/file/ui.py index 9da2b4a..11d491f 100644 --- a/libs/ktem/ktem/index/file/ui.py +++ b/libs/ktem/ktem/index/file/ui.py @@ -1,29 +1,48 @@ import os import tempfile +from pathlib import Path import gradio as gr import pandas as pd +from gradio.data_classes import FileData +from gradio.utils import NamedString from ktem.app import BasePage from ktem.db.engine import engine from sqlalchemy import select from sqlalchemy.orm import Session +class File(gr.File): + """Subclass from gr.File to maintain the original filename + + The issue happens when user uploads file with name like: !@#$%%^&*().pdf + """ + + def _process_single_file(self, f: FileData) -> NamedString | bytes: + file_name = f.path + if self.type == "filepath": + if f.orig_name and Path(file_name).name != f.orig_name: + file_name = str(Path(file_name).parent / f.orig_name) + os.rename(f.path, file_name) + file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE) + file.name = file_name + return NamedString(file_name) + elif self.type == "binary": + with open(file_name, "rb") as file_data: + return file_data.read() + else: + raise ValueError( + "Unknown type: " + + str(type) + + ". Please choose from: 'filepath', 'binary'." + ) + + class DirectoryUpload(BasePage): - def __init__(self, app): - self._app = app - self._supported_file_types = [ - "image", - ".pdf", - ".txt", - ".csv", - ".xlsx", - ".doc", - ".docx", - ".pptx", - ".html", - ".zip", - ] + def __init__(self, app, index): + super().__init__(app) + self._index = index + self._supported_file_types = self._index.config.get("supported_file_types", []) self.on_building_ui() def on_building_ui(self): @@ -50,18 +69,7 @@ class FileIndexPage(BasePage): def __init__(self, app, index): super().__init__(app) self._index = index - self._supported_file_types = [ - "image", - ".pdf", - ".txt", - ".csv", - ".xlsx", - ".doc", - ".docx", - ".pptx", - ".html", - ".zip", - ] + self._supported_file_types = self._index.config.get("supported_file_types", []) self.selected_panel_false = "Selected file: (please select above)" self.selected_panel_true = "Selected file: {name}" # TODO: on_building_ui is not correctly named if it's always called in @@ -69,13 +77,32 @@ class FileIndexPage(BasePage): self.public_events = [f"onFileIndex{index.id}Changed"] self.on_building_ui() + def upload_instruction(self) -> str: + msgs = [] + if self._supported_file_types: + msgs.append( + f"- Supported file types: {', '.join(self._supported_file_types)}" + ) + + if max_file_size := self._index.config.get("max_file_size", 0): + msgs.append(f"- Maximum file size: {max_file_size} MB") + + if max_number_of_files := self._index.config.get("max_number_of_files", 0): + msgs.append(f"- The index can have maximum {max_number_of_files} files") + + if msgs: + return "\n".join(msgs) + + return "" + def on_building_ui(self): """Build the UI of the app""" - with gr.Accordion(label="File upload", open=False): - gr.Markdown( - f"Supported file types: {', '.join(self._supported_file_types)}", - ) - self.files = gr.File( + with gr.Accordion(label="File upload", open=True) as self.upload: + msg = self.upload_instruction() + if msg: + gr.Markdown(msg) + + self.files = File( file_types=self._supported_file_types, file_count="multiple", container=False, @@ -98,18 +125,20 @@ class FileIndexPage(BasePage): interactive=False, ) - with gr.Row(): + with gr.Row() as self.selection_info: self.selected_file_id = gr.State(value=None) self.selected_panel = gr.Markdown(self.selected_panel_false) self.deselect_button = gr.Button("Deselect", visible=False) - with gr.Row(): + with gr.Row() as self.tools: with gr.Column(): self.view_button = gr.Button("View Text (WIP)") with gr.Column(): self.delete_button = gr.Button("Delete") with gr.Row(): - self.delete_yes = gr.Button("Confirm Delete", visible=False) + self.delete_yes = gr.Button( + "Confirm Delete", variant="primary", visible=False + ) self.delete_no = gr.Button("Cancel", visible=False) def on_subscribe_public_events(self): @@ -242,10 +271,12 @@ class FileIndexPage(BasePage): self._app.settings_state, ], outputs=[self.file_output], + concurrency_limit=20, ).then( fn=self.list_file, inputs=None, outputs=[self.file_list_state, self.file_list], + concurrency_limit=20, ) for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): onUploaded = onUploaded.then(**event) @@ -274,6 +305,15 @@ class FileIndexPage(BasePage): selected_files: the list of files already selected settings: the settings of the app """ + if not files: + gr.Info("No uploaded file") + return gr.update() + + errors = self.validate(files) + if errors: + gr.Warning(", ".join(errors)) + return gr.update() + gr.Info(f"Start indexing {len(files)} files...") # get the pipeline @@ -409,6 +449,35 @@ class FileIndexPage(BasePage): name=list_files["name"][ev.index[0]] ) + def validate(self, files: list[str]): + """Validate if the files are valid""" + paths = [Path(file) for file in files] + errors = [] + if max_file_size := self._index.config.get("max_file_size", 0): + errors_max_size = [] + for path in paths: + if path.stat().st_size > max_file_size * 1e6: + errors_max_size.append(path.name) + if errors_max_size: + str_errors = ", ".join(errors_max_size) + if len(str_errors) > 60: + str_errors = str_errors[:55] + "..." + errors.append( + f"Maximum file size ({max_file_size} MB) exceeded: {str_errors}" + ) + + if max_number_of_files := self._index.config.get("max_number_of_files", 0): + with Session(engine) as session: + current_num_files = session.query( + self._index._db_tables["Source"].id + ).count() + if len(paths) + current_num_files > max_number_of_files: + errors.append( + f"Maximum number of files ({max_number_of_files}) will be exceeded" + ) + + return errors + class FileSelector(BasePage): """File selector UI in the Chat page""" @@ -430,6 +499,9 @@ class FileSelector(BasePage): def as_gradio_component(self): return self.selector + def get_selected_ids(self, selected): + return selected + def load_files(self, selected_files): options = [] available_ids = [] diff --git a/libs/ktem/ktem/index/manager.py b/libs/ktem/ktem/index/manager.py index 72c4f99..af1c8d4 100644 --- a/libs/ktem/ktem/index/manager.py +++ b/libs/ktem/ktem/index/manager.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Optional, Type from ktem.db.models import engine from sqlmodel import Session, select @@ -49,15 +49,19 @@ class IndexManager: Returns: BaseIndex: the index object """ + index_cls = import_dotted_string(index_type, safe=False) + index = index_cls(app=self._app, id=id, name=name, config=config) + index.on_create() + with Session(engine) as session: - index_entry = Index(id=id, name=name, config=config, index_type=index_type) + index_entry = Index( + id=index.id, name=index.name, config=index.config, index_type=index_type + ) session.add(index_entry) session.commit() session.refresh(index_entry) - index_cls = import_dotted_string(index_type, safe=False) - index = index_cls(app=self._app, id=id, name=name, config=config) - index.on_create() + index.id = index_entry.id return index @@ -77,7 +81,7 @@ class IndexManager: self._indices.append(index) return index - def exists(self, id: int) -> bool: + def exists(self, id: Optional[int] = None, name: Optional[str] = None) -> bool: """Check if the index exists Args: @@ -86,9 +90,19 @@ class IndexManager: Returns: bool: True if the index exists, False otherwise """ - with Session(engine) as session: - index = session.get(Index, id) - return index is not None + if id: + with Session(engine) as session: + index = session.get(Index, id) + return index is not None + + if name: + with Session(engine) as session: + index = session.exec( + select(Index).where(Index.name == name) + ).one_or_none() + return index is not None + + return False def on_application_startup(self): """This method is called by the base application when the application starts diff --git a/libs/ktem/ktem/main.py b/libs/ktem/ktem/main.py index c375ed7..1d76d04 100644 --- a/libs/ktem/ktem/main.py +++ b/libs/ktem/ktem/main.py @@ -27,7 +27,7 @@ class App(BaseApp): if self.f_user_management: from ktem.pages.login import LoginPage - with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]: + with gr.Tab("Welcome", elem_id="login-tab") as self._tabs["login-tab"]: self.login_page = LoginPage(self) with gr.Tab( @@ -62,6 +62,9 @@ class App(BaseApp): def on_subscribe_public_events(self): if self.f_user_management: + from ktem.db.engine import engine + from ktem.db.models import User + from sqlmodel import Session, select def signed_in_out(user_id): if not user_id: @@ -73,14 +76,31 @@ class App(BaseApp): ) for k in self._tabs.keys() ) - return list( - ( - gr.update(visible=True) - if k != "login-tab" - else gr.update(visible=False) - ) - for k in self._tabs.keys() - ) + + with Session(engine) as session: + user = session.exec(select(User).where(User.id == user_id)).first() + if user is None: + return list( + ( + gr.update(visible=True) + if k == "login-tab" + else gr.update(visible=False) + ) + for k in self._tabs.keys() + ) + + is_admin = user.admin + + tabs_update = [] + for k in self._tabs.keys(): + if k == "login-tab": + tabs_update.append(gr.update(visible=False)) + elif k == "admin-tab": + tabs_update.append(gr.update(visible=is_admin)) + else: + tabs_update.append(gr.update(visible=True)) + + return tabs_update self.subscribe_event( name="onSignIn", diff --git a/libs/ktem/ktem/pages/admin/user.py b/libs/ktem/ktem/pages/admin/user.py index 519fb0f..6411b30 100644 --- a/libs/ktem/ktem/pages/admin/user.py +++ b/libs/ktem/ktem/pages/admin/user.py @@ -40,7 +40,7 @@ def validate_username(usn): if len(usn) > 32: errors.append("Username must be at most 32 characters long") - if not usn.strip("_").isalnum(): + if not usn.replace("_", "").isalnum(): errors.append( "Username must contain only alphanumeric characters and underscores" ) @@ -97,8 +97,6 @@ def validate_password(pwd, pwd_cnf): class UserManagement(BasePage): def __init__(self, app): self._app = app - self.selected_panel_false = "Selected user: (please select above)" - self.selected_panel_true = "Selected user: {name}" self.on_building_ui() if hasattr(flowsettings, "KH_FEATURE_USER_MANAGEMENT_ADMIN") and hasattr( @@ -126,7 +124,38 @@ class UserManagement(BasePage): gr.Info(f'User "{usn}" created successfully') def on_building_ui(self): - with gr.Accordion(label="Create user", open=False): + with gr.Tab(label="User list"): + self.state_user_list = gr.State(value=None) + self.user_list = gr.DataFrame( + headers=["id", "name", "admin"], + interactive=False, + ) + + with gr.Group(visible=False) as self._selected_panel: + self.selected_user_id = gr.Number(value=-1, visible=False) + self.usn_edit = gr.Textbox(label="Username") + with gr.Row(): + self.pwd_edit = gr.Textbox(label="Change password", type="password") + self.pwd_cnf_edit = gr.Textbox( + label="Confirm change password", + type="password", + ) + self.admin_edit = gr.Checkbox(label="Admin") + + with gr.Row() as self._selected_panel_btn: + with gr.Column(): + self.btn_edit_save = gr.Button("Save") + with gr.Column(): + self.btn_delete = gr.Button("Delete") + with gr.Row(): + self.btn_delete_yes = gr.Button( + "Confirm delete", variant="primary", visible=False + ) + self.btn_delete_no = gr.Button("Cancel", visible=False) + with gr.Column(): + self.btn_close = gr.Button("Close") + + with gr.Tab(label="Create user"): self.usn_new = gr.Textbox(label="Username", interactive=True) self.pwd_new = gr.Textbox( label="Password", type="password", interactive=True @@ -139,52 +168,28 @@ class UserManagement(BasePage): gr.Markdown(PASSWORD_RULE) self.btn_new = gr.Button("Create user") - gr.Markdown("## User list") - self.btn_list_user = gr.Button("Refresh user list") - self.state_user_list = gr.State(value=None) - self.user_list = gr.DataFrame( - headers=["id", "name", "admin"], - interactive=False, - ) - - with gr.Row(): - self.selected_user_id = gr.State(value=None) - self.selected_panel = gr.Markdown(self.selected_panel_false) - self.deselect_button = gr.Button("Deselect", visible=False) - - with gr.Group(): - self.btn_delete = gr.Button("Delete user") - with gr.Row(): - self.btn_delete_yes = gr.Button("Confirm", visible=False) - self.btn_delete_no = gr.Button("Cancel", visible=False) - - gr.Markdown("## User details") - self.usn_edit = gr.Textbox(label="Username") - self.pwd_edit = gr.Textbox(label="Password", type="password") - self.pwd_cnf_edit = gr.Textbox(label="Confirm password", type="password") - self.admin_edit = gr.Checkbox(label="Admin") - self.btn_edit_save = gr.Button("Save") - def on_register_events(self): self.btn_new.click( self.create_user, inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new], - outputs=None, - ) - self.btn_list_user.click( - self.list_users, inputs=None, outputs=[self.state_user_list, self.user_list] + outputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new], + ).then( + self.list_users, + inputs=self._app.user_id, + outputs=[self.state_user_list, self.user_list], ) self.user_list.select( self.select_user, inputs=self.user_list, - outputs=[self.selected_user_id, self.selected_panel], + outputs=[self.selected_user_id], show_progress="hidden", ) - self.selected_panel.change( + self.selected_user_id.change( self.on_selected_user_change, inputs=[self.selected_user_id], outputs=[ - self.deselect_button, + self._selected_panel, + self._selected_panel_btn, # delete section self.btn_delete, self.btn_delete_yes, @@ -197,12 +202,6 @@ class UserManagement(BasePage): ], show_progress="hidden", ) - self.deselect_button.click( - lambda: (None, self.selected_panel_false), - inputs=None, - outputs=[self.selected_user_id, self.selected_panel], - show_progress="hidden", - ) self.btn_delete.click( self.on_btn_delete_click, inputs=[self.selected_user_id], @@ -211,9 +210,13 @@ class UserManagement(BasePage): ) self.btn_delete_yes.click( self.delete_user, - inputs=[self.selected_user_id], - outputs=[self.selected_user_id, self.selected_panel], + inputs=[self._app.user_id, self.selected_user_id], + outputs=[self.selected_user_id], show_progress="hidden", + ).then( + self.list_users, + inputs=self._app.user_id, + outputs=[self.state_user_list, self.user_list], ) self.btn_delete_no.click( lambda: ( @@ -234,21 +237,53 @@ class UserManagement(BasePage): self.pwd_cnf_edit, self.admin_edit, ], - outputs=None, + outputs=[self.pwd_edit, self.pwd_cnf_edit], show_progress="hidden", + ).then( + self.list_users, + inputs=self._app.user_id, + outputs=[self.state_user_list, self.user_list], + ) + self.btn_close.click( + lambda: -1, + outputs=[self.selected_user_id], + ) + + def on_subscribe_public_events(self): + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.list_users, + "inputs": [self._app.user_id], + "outputs": [self.state_user_list, self.user_list], + }, + ) + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": lambda: ("", "", "", None, None, -1), + "outputs": [ + self.usn_new, + self.pwd_new, + self.pwd_cnf_new, + self.state_user_list, + self.user_list, + self.selected_user_id, + ], + }, ) def create_user(self, usn, pwd, pwd_cnf): errors = validate_username(usn) if errors: gr.Warning(errors) - return + return usn, pwd, pwd_cnf errors = validate_password(pwd, pwd_cnf) print(errors) if errors: gr.Warning(errors) - return + return usn, pwd, pwd_cnf with Session(engine) as session: statement = select(User).where(User.username_lower == usn.lower()) @@ -265,8 +300,22 @@ class UserManagement(BasePage): session.commit() gr.Info(f'User "{usn}" created successfully') - def list_users(self): + return "", "", "" + + def list_users(self, user_id): + if user_id is None: + return [], pd.DataFrame.from_records( + [{"id": "-", "username": "-", "admin": "-"}] + ) + with Session(engine) as session: + statement = select(User).where(User.id == user_id) + user = session.exec(statement).one() + if not user.admin: + return [], pd.DataFrame.from_records( + [{"id": "-", "username": "-", "admin": "-"}] + ) + statement = select(User) results = [ {"id": user.id, "username": user.username, "admin": user.admin} @@ -284,18 +333,17 @@ class UserManagement(BasePage): def select_user(self, user_list, ev: gr.SelectData): if ev.value == "-" and ev.index[0] == 0: gr.Info("No user is loaded. Please refresh the user list") - return None, self.selected_panel_false + return -1 if not ev.selected: - return None, self.selected_panel_false + return -1 - return user_list["id"][ev.index[0]], self.selected_panel_true.format( - name=user_list["username"][ev.index[0]] - ) + return user_list["id"][ev.index[0]] def on_selected_user_change(self, selected_user_id): - if selected_user_id is None: - deselect_button = gr.update(visible=False) + if selected_user_id == -1: + _selected_panel = gr.update(visible=False) + _selected_panel_btn = gr.update(visible=False) btn_delete = gr.update(visible=True) btn_delete_yes = gr.update(visible=False) btn_delete_no = gr.update(visible=False) @@ -304,7 +352,8 @@ class UserManagement(BasePage): pwd_cnf_edit = gr.update(value="") admin_edit = gr.update(value=False) else: - deselect_button = gr.update(visible=True) + _selected_panel = gr.update(visible=True) + _selected_panel_btn = gr.update(visible=True) btn_delete = gr.update(visible=True) btn_delete_yes = gr.update(visible=False) btn_delete_no = gr.update(visible=False) @@ -319,7 +368,8 @@ class UserManagement(BasePage): admin_edit = gr.update(value=user.admin) return ( - deselect_button, + _selected_panel, + _selected_panel_btn, btn_delete, btn_delete_yes, btn_delete_no, @@ -344,17 +394,16 @@ class UserManagement(BasePage): return btn_delete, btn_delete_yes, btn_delete_no def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin): - if usn: - errors = validate_username(usn) - if errors: - gr.Warning(errors) - return + errors = validate_username(usn) + if errors: + gr.Warning(errors) + return pwd, pwd_cnf if pwd: errors = validate_password(pwd, pwd_cnf) if errors: gr.Warning(errors) - return + return pwd, pwd_cnf with Session(engine) as session: statement = select(User).where(User.id == int(selected_user_id)) @@ -367,11 +416,17 @@ class UserManagement(BasePage): session.commit() gr.Info(f'User "{usn}" updated successfully') - def delete_user(self, selected_user_id): + return "", "" + + def delete_user(self, current_user, selected_user_id): + if current_user == selected_user_id: + gr.Warning("You cannot delete yourself") + return selected_user_id + with Session(engine) as session: statement = select(User).where(User.id == int(selected_user_id)) user = session.exec(statement).one() session.delete(user) session.commit() gr.Info(f'User "{user.username}" deleted successfully') - return None, self.selected_panel_false + return -1 diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index d2bba87..a83bd83 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -7,8 +7,10 @@ from ktem.app import BasePage from ktem.components import reasonings from ktem.db.models import Conversation, engine from sqlmodel import Session, select +from theflow.settings import settings as flowsettings from .chat_panel import ChatPanel +from .chat_suggestion import ChatSuggestion from .common import STATE from .control import ConversationControl from .report import ReportIssue @@ -26,24 +28,39 @@ class ChatPage(BasePage): with gr.Column(scale=1): self.chat_control = ConversationControl(self._app) + if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): + self.chat_suggestion = ChatSuggestion(self._app) + for index in self._app.index_manager.indices: - index.selector = -1 + index.selector = None index_ui = index.get_selector_component_ui() if not index_ui: + # the index doesn't have a selector UI component continue - index_ui.unrender() + index_ui.unrender() # need to rerender later within Accordion with gr.Accordion(label=f"{index.name} Index", open=False): index_ui.render() gr_index = index_ui.as_gradio_component() if gr_index: - index.selector = len(self._indices_input) - self._indices_input.append(gr_index) + if isinstance(gr_index, list): + index.selector = tuple( + range( + len(self._indices_input), + len(self._indices_input) + len(gr_index), + ) + ) + self._indices_input.extend(gr_index) + else: + index.selector = len(self._indices_input) + self._indices_input.append(gr_index) setattr(self, f"_index_{index.id}", index_ui) self.report_issue = ReportIssue(self._app) + with gr.Column(scale=6): self.chat_panel = ChatPanel(self._app) + with gr.Column(scale=3): with gr.Accordion(label="Information panel", open=True): self.info_panel = gr.HTML(elem_id="chat-info-panel") @@ -54,11 +71,24 @@ class ChatPage(BasePage): self.chat_panel.text_input.submit, self.chat_panel.submit_btn.click, ], - fn=self.chat_panel.submit_msg, - inputs=[self.chat_panel.text_input, self.chat_panel.chatbot], - outputs=[self.chat_panel.text_input, self.chat_panel.chatbot], + fn=self.submit_msg, + inputs=[ + self.chat_panel.text_input, + self.chat_panel.chatbot, + self._app.user_id, + self.chat_control.conversation_id, + self.chat_control.conversation_rn, + ], + outputs=[ + self.chat_panel.text_input, + self.chat_panel.chatbot, + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + ], + concurrency_limit=20, show_progress="hidden", - ).then( + ).success( fn=self.chat_fn, inputs=[ self.chat_control.conversation_id, @@ -72,6 +102,7 @@ class ChatPage(BasePage): self.info_panel, self.chat_state, ], + concurrency_limit=20, show_progress="minimal", ).then( fn=self.update_data_source, @@ -82,6 +113,7 @@ class ChatPage(BasePage): ] + self._indices_input, outputs=None, + concurrency_limit=20, ) self.chat_panel.regen_btn.click( @@ -98,6 +130,7 @@ class ChatPage(BasePage): self.info_panel, self.chat_state, ], + concurrency_limit=20, show_progress="minimal", ).then( fn=self.update_data_source, @@ -108,6 +141,7 @@ class ChatPage(BasePage): ] + self._indices_input, outputs=None, + concurrency_limit=20, ) self.chat_panel.chatbot.like( @@ -116,7 +150,12 @@ class ChatPage(BasePage): outputs=None, ) - self.chat_control.conversation.change( + self.chat_control.btn_new.click( + self.chat_control.new_conv, + inputs=self._app.user_id, + outputs=[self.chat_control.conversation_id, self.chat_control.conversation], + show_progress="hidden", + ).then( self.chat_control.select_conv, inputs=[self.chat_control.conversation], outputs=[ @@ -124,12 +163,71 @@ class ChatPage(BasePage): self.chat_control.conversation, self.chat_control.conversation_rn, self.chat_panel.chatbot, + self.info_panel, self.chat_state, ] + self._indices_input, show_progress="hidden", ) + self.chat_control.btn_del.click( + lambda id: self.toggle_delete(id), + inputs=[self.chat_control.conversation_id], + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.chat_control.btn_del_conf.click( + self.chat_control.delete_conv, + inputs=[self.chat_control.conversation_id, self._app.user_id], + outputs=[self.chat_control.conversation_id, self.chat_control.conversation], + show_progress="hidden", + ).then( + self.chat_control.select_conv, + inputs=[self.chat_control.conversation], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.info_panel, + ] + + self._indices_input, + show_progress="hidden", + ).then( + lambda: self.toggle_delete(""), + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.chat_control.btn_del_cnl.click( + lambda: self.toggle_delete(""), + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.chat_control.conversation_rn_btn.click( + self.chat_control.rename_conv, + inputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation_rn, + self._app.user_id, + ], + outputs=[self.chat_control.conversation, self.chat_control.conversation], + show_progress="hidden", + ) + + self.chat_control.conversation.select( + self.chat_control.select_conv, + inputs=[self.chat_control.conversation], + outputs=[ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + self.info_panel, + ] + + self._indices_input, + show_progress="hidden", + ).then( + lambda: self.toggle_delete(""), + outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm], + ) + self.report_issue.report_btn.click( self.report_issue.report, inputs=[ @@ -140,11 +238,77 @@ class ChatPage(BasePage): self.chat_panel.chatbot, self._app.settings_state, self._app.user_id, + self.info_panel, self.chat_state, ] + self._indices_input, outputs=None, ) + if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False): + self.chat_suggestion.example.select( + self.chat_suggestion.select_example, + outputs=[self.chat_panel.text_input], + show_progress="hidden", + ) + + def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name): + """Submit a message to the chatbot""" + if not chat_input: + raise ValueError("Input is empty") + + if not conv_id: + id_, update = self.chat_control.new_conv(user_id) + with Session(engine) as session: + statement = select(Conversation).where(Conversation.id == id_) + name = session.exec(statement).one().name + new_conv_id = id_ + conv_update = update + new_conv_name = name + else: + new_conv_id = conv_id + conv_update = gr.update() + new_conv_name = conv_name + + return ( + "", + chat_history + [(chat_input, None)], + new_conv_id, + conv_update, + new_conv_name, + ) + + def toggle_delete(self, conv_id): + if conv_id: + return gr.update(visible=False), gr.update(visible=True) + else: + return gr.update(visible=True), gr.update(visible=False) + + def on_subscribe_public_events(self): + if self._app.f_user_management: + self._app.subscribe_event( + name="onSignIn", + definition={ + "fn": self.chat_control.reload_conv, + "inputs": [self._app.user_id], + "outputs": [self.chat_control.conversation], + "show_progress": "hidden", + }, + ) + + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": lambda: self.chat_control.select_conv(""), + "outputs": [ + self.chat_control.conversation_id, + self.chat_control.conversation, + self.chat_control.conversation_rn, + self.chat_panel.chatbot, + ] + + self._indices_input, + "show_progress": "hidden", + }, + ) def update_data_source(self, convo_id, messages, state, *selecteds): """Update the data source""" @@ -154,8 +318,12 @@ class ChatPage(BasePage): selecteds_ = {} for index in self._app.index_manager.indices: - if index.selector != -1: + if index.selector is None: + continue + if isinstance(index.selector, int): selecteds_[str(index.id)] = selecteds[index.selector] + else: + selecteds_[str(index.id)] = [selecteds[i] for i in index.selector] with Session(engine) as session: statement = select(Conversation).where(Conversation.id == convo_id) @@ -205,8 +373,11 @@ class ChatPage(BasePage): retrievers = [] for index in self._app.index_manager.indices: index_selected = [] - if index.selector != -1: + if isinstance(index.selector, int): index_selected = selecteds[index.selector] + if isinstance(index.selector, tuple): + for i in index.selector: + index_selected.append(selecteds[i]) iretrievers = index.get_retriever_pipelines(settings, index_selected) retrievers += iretrievers @@ -250,7 +421,10 @@ class ChatPage(BasePage): break if "output" in response: - text += response["output"] + if response["output"] is None: + text = "" + else: + text += response["output"] if "evidence" in response: if response["evidence"] is None: diff --git a/libs/ktem/ktem/pages/chat/chat_suggestion.py b/libs/ktem/ktem/pages/chat/chat_suggestion.py new file mode 100644 index 0000000..23332c0 --- /dev/null +++ b/libs/ktem/ktem/pages/chat/chat_suggestion.py @@ -0,0 +1,26 @@ +import gradio as gr +from ktem.app import BasePage +from theflow.settings import settings as flowsettings + + +class ChatSuggestion(BasePage): + def __init__(self, app): + self._app = app + self.on_building_ui() + + def on_building_ui(self): + chat_samples = getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", []) + chat_samples = [[each] for each in chat_samples] + with gr.Accordion(label="Chat Suggestion", open=False) as self.accordion: + self.example = gr.DataFrame( + value=chat_samples, + headers=["Sample"], + interactive=False, + wrap=True, + ) + + def as_gradio_component(self): + return self.example + + def select_example(self, ev: gr.SelectData): + return ev.value diff --git a/libs/ktem/ktem/pages/chat/control.py b/libs/ktem/ktem/pages/chat/control.py index e714112..f2ed99b 100644 --- a/libs/ktem/ktem/pages/chat/control.py +++ b/libs/ktem/ktem/pages/chat/control.py @@ -10,6 +10,17 @@ from .common import STATE logger = logging.getLogger(__name__) +def is_conv_name_valid(name): + """Check if the conversation name is valid""" + errors = [] + if len(name) == 0: + errors.append("Name cannot be empty") + elif len(name) > 40: + errors.append("Name cannot be longer than 40 characters") + + return "; ".join(errors) + + class ConversationControl(BasePage): """Manage conversation""" @@ -28,9 +39,17 @@ class ConversationControl(BasePage): interactive=True, ) - with gr.Row(): - self.conversation_new_btn = gr.Button(value="New", min_width=10) - self.conversation_del_btn = gr.Button(value="Delete", min_width=10) + with gr.Row() as self._new_delete: + self.btn_new = gr.Button(value="New", min_width=10) + self.btn_del = gr.Button(value="Delete", min_width=10) + + with gr.Row(visible=False) as self._delete_confirm: + self.btn_del_conf = gr.Button( + value="Delete", + variant="primary", + min_width=10, + ) + self.btn_del_cnl = gr.Button(value="Cancel", min_width=10) with gr.Row(): self.conversation_rn = gr.Text( @@ -52,48 +71,6 @@ class ConversationControl(BasePage): # outputs=[current_state], # ) - def on_subscribe_public_events(self): - if self._app.f_user_management: - self._app.subscribe_event( - name="onSignIn", - definition={ - "fn": self.reload_conv, - "inputs": [self._app.user_id], - "outputs": [self.conversation], - "show_progress": "hidden", - }, - ) - - self._app.subscribe_event( - name="onSignOut", - definition={ - "fn": self.reload_conv, - "inputs": [self._app.user_id], - "outputs": [self.conversation], - "show_progress": "hidden", - }, - ) - - def on_register_events(self): - self.conversation_new_btn.click( - self.new_conv, - inputs=self._app.user_id, - outputs=[self.conversation_id, self.conversation], - show_progress="hidden", - ) - self.conversation_del_btn.click( - self.delete_conv, - inputs=[self.conversation_id, self._app.user_id], - outputs=[self.conversation_id, self.conversation], - show_progress="hidden", - ) - self.conversation_rn_btn.click( - self.rename_conv, - inputs=[self.conversation_id, self.conversation_rn, self._app.user_id], - outputs=[self.conversation, self.conversation], - show_progress="hidden", - ) - def load_chat_history(self, user_id): """Reload chat history""" options = [] @@ -112,7 +89,7 @@ class ConversationControl(BasePage): def reload_conv(self, user_id): conv_list = self.load_chat_history(user_id) if conv_list: - return gr.update(value=conv_list[0][1], choices=conv_list) + return gr.update(value=None, choices=conv_list) else: return gr.update(value=None, choices=[]) @@ -133,10 +110,15 @@ class ConversationControl(BasePage): return id_, gr.update(value=id_, choices=history) def delete_conv(self, conversation_id, user_id): - """Create new chat""" + """Delete the selected conversation""" + if not conversation_id: + gr.Warning("No conversation selected.") + return None, gr.update() + if user_id is None: gr.Warning("Please sign in first (Settings → User Settings)") return None, gr.update() + with Session(engine) as session: statement = select(Conversation).where(Conversation.id == conversation_id) result = session.exec(statement).one() @@ -161,6 +143,7 @@ class ConversationControl(BasePage): name = result.name selected = result.data_source.get("selected", {}) chats = result.data_source.get("messages", []) + info_panel = "" state = result.data_source.get("state", STATE) except Exception as e: logger.warning(e) @@ -168,22 +151,36 @@ class ConversationControl(BasePage): name = "" selected = {} chats = [] + info_panel = "" state = STATE indices = [] for index in self._app.index_manager.indices: # assume that the index has selector - if index.selector == -1: + if index.selector is None: continue - indices.append(selected.get(str(index.id), [])) + if isinstance(index.selector, int): + indices.append(selected.get(str(index.id), [])) + if isinstance(index.selector, tuple): + indices.extend(selected.get(str(index.id), [[]] * len(index.selector))) - return id_, id_, name, chats, state, *indices + return id_, id_, name, chats, info_panel, state, *indices def rename_conv(self, conversation_id, new_name, user_id): """Rename the conversation""" if user_id is None: gr.Warning("Please sign in first (Settings → User Settings)") return gr.update(), "" + + if not conversation_id: + gr.Warning("No conversation selected.") + return gr.update(), "" + + errors = is_conv_name_valid(new_name) + if errors: + gr.Warning(errors) + return gr.update(), conversation_id + with Session(engine) as session: statement = select(Conversation).where(Conversation.id == conversation_id) result = session.exec(statement).one() diff --git a/libs/ktem/ktem/pages/chat/report.py b/libs/ktem/ktem/pages/chat/report.py index 25d83f8..dfe0301 100644 --- a/libs/ktem/ktem/pages/chat/report.py +++ b/libs/ktem/ktem/pages/chat/report.py @@ -48,13 +48,19 @@ class ReportIssue(BasePage): chat_history: list, settings: dict, user_id: Optional[int], + info_panel: str, chat_state: dict, - *selecteds + *selecteds, ): selecteds_ = {} for index in self._app.index_manager.indices: - if index.selector != -1: - selecteds_[str(index.id)] = selecteds[index.selector] + if index.selector is not None: + if isinstance(index.selector, int): + selecteds_[str(index.id)] = selecteds[index.selector] + elif isinstance(index.selector, tuple): + selecteds_[str(index.id)] = [selecteds[_] for _ in index.selector] + else: + print(f"Unknown selector type: {index.selector}") with Session(engine) as session: issue = IssueReport( @@ -66,6 +72,7 @@ class ReportIssue(BasePage): chat={ "conv_id": conv_id, "chat_history": chat_history, + "info_panel": info_panel, "chat_state": chat_state, "selecteds": selecteds_, }, diff --git a/libs/ktem/ktem/pages/login.py b/libs/ktem/ktem/pages/login.py index 6fe15d0..d5c57e5 100644 --- a/libs/ktem/ktem/pages/login.py +++ b/libs/ktem/ktem/pages/login.py @@ -31,11 +31,10 @@ class LoginPage(BasePage): self.on_building_ui() def on_building_ui(self): - gr.Markdown("Welcome to Kotaemon") - self.usn = gr.Textbox(label="Username") - self.pwd = gr.Textbox(label="Password", type="password") - self.btn_login = gr.Button("Login") - self._dummy = gr.State() + gr.Markdown("# Welcome to Kotaemon") + self.usn = gr.Textbox(label="Username", visible=False) + self.pwd = gr.Textbox(label="Password", type="password", visible=False) + self.btn_login = gr.Button("Login", visible=False) def on_register_events(self): onSignIn = gr.on( @@ -45,24 +44,56 @@ class LoginPage(BasePage): outputs=[self._app.user_id, self.usn, self.pwd], show_progress="hidden", js=signin_js, + ).then( + self.toggle_login_visibility, + inputs=[self._app.user_id], + outputs=[self.usn, self.pwd, self.btn_login], ) for event in self._app.get_event("onSignIn"): onSignIn = onSignIn.success(**event) + def toggle_login_visibility(self, user_id): + return ( + gr.update(visible=user_id is None), + gr.update(visible=user_id is None), + gr.update(visible=user_id is None), + ) + def _on_app_created(self): - self._app.app.load( - None, - inputs=None, - outputs=[self.usn, self.pwd], + onSignIn = self._app.app.load( + self.login, + inputs=[self.usn, self.pwd], + outputs=[self._app.user_id, self.usn, self.pwd], + show_progress="hidden", js=fetch_creds, + ).then( + self.toggle_login_visibility, + inputs=[self._app.user_id], + outputs=[self.usn, self.pwd, self.btn_login], + ) + for event in self._app.get_event("onSignIn"): + onSignIn = onSignIn.success(**event) + + def on_subscribe_public_events(self): + self._app.subscribe_event( + name="onSignOut", + definition={ + "fn": self.toggle_login_visibility, + "inputs": [self._app.user_id], + "outputs": [self.usn, self.pwd, self.btn_login], + "show_progress": "hidden", + }, ) def login(self, usn, pwd): + if not usn or not pwd: + return None, usn, pwd hashed_password = hashlib.sha256(pwd.encode()).hexdigest() with Session(engine) as session: stmt = select(User).where( - User.username_lower == usn.lower(), User.password == hashed_password + User.username_lower == usn.lower().strip(), + User.password == hashed_password, ) result = session.exec(stmt).all() if result: diff --git a/libs/ktem/ktem/pages/settings.py b/libs/ktem/ktem/pages/settings.py index 0fce2e8..20912cb 100644 --- a/libs/ktem/ktem/pages/settings.py +++ b/libs/ktem/ktem/pages/settings.py @@ -164,9 +164,14 @@ class SettingsPage(BasePage): show_progress="hidden", ) onSignOutClick = self.signout.click( - lambda: (None, "Current user: ___"), + lambda: (None, "Current user: ___", "", ""), inputs=None, - outputs=[self._user_id, self.current_name], + outputs=[ + self._user_id, + self.current_name, + self.password_change, + self.password_change_confirm, + ], show_progress="hidden", js=signout_js, ).then( @@ -192,8 +197,12 @@ class SettingsPage(BasePage): self.password_change_btn = gr.Button("Change password", interactive=True) def change_password(self, user_id, password, password_confirm): - if password != password_confirm: - gr.Warning("Password does not match") + from ktem.pages.admin.user import validate_password + + errors = validate_password(password, password_confirm) + if errors: + print(errors) + gr.Warning(errors) return password, password_confirm with Session(engine) as session: diff --git a/libs/ktem/ktem/reasoning/base.py b/libs/ktem/ktem/reasoning/base.py index 80cf016..6d6e486 100644 --- a/libs/ktem/ktem/reasoning/base.py +++ b/libs/ktem/ktem/reasoning/base.py @@ -34,12 +34,16 @@ class BaseReasoning(BaseComponent): @classmethod def get_pipeline( - cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None + cls, + user_settings: dict, + state: dict, + retrievers: Optional[list["BaseComponent"]] = None, ) -> "BaseReasoning": """Get the reasoning pipeline for the app to execute Args: user_setting: user settings + state: conversation state retrievers (list): List of retrievers """ return cls() diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 23d8363..082c20f 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -22,6 +22,8 @@ from kotaemon.indices.splitters import TokenSplitter from kotaemon.llms import ChatLLM, PromptTemplate from kotaemon.loaders.utils.gpt4v import stream_gpt4v +from .base import BaseReasoning + logger = logging.getLogger(__name__) EVIDENCE_MODE_TEXT = 0 @@ -204,7 +206,7 @@ class AnswerWithContextPipeline(BaseComponent): lang: str = "English" # support English and Japanese async def run( # type: ignore - self, question: str, evidence: str, evidence_mode: int = 0 + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs ) -> Document: """Answer the question based on the evidence @@ -336,7 +338,7 @@ class RewriteQuestionPipeline(BaseComponent): return Document(text=output) -class FullQAPipeline(BaseComponent): +class FullQAPipeline(BaseReasoning): """Question answering pipeline. Handle from question to answer""" class Config: @@ -352,6 +354,8 @@ class FullQAPipeline(BaseComponent): async def run( # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore ) -> Document: # type: ignore + import markdown + docs = [] doc_ids = [] if self.use_rewrite: @@ -364,12 +368,16 @@ class FullQAPipeline(BaseComponent): docs.append(doc) doc_ids.append(doc.doc_id) for doc in docs: + # TODO: a better approach to show the information + text = markdown.markdown( + doc.text, extensions=["markdown.extensions.tables"] + ) self.report_output( { "evidence": ( "
" f"{doc.metadata['file_name']}" - f"{doc.text}" + f"{text}" "

" ) } @@ -378,7 +386,12 @@ class FullQAPipeline(BaseComponent): evidence_mode, evidence = self.evidence_pipeline(docs).content answer = await self.answering_pipeline( - question=message, evidence=evidence, evidence_mode=evidence_mode + question=message, + history=history, + evidence=evidence, + evidence_mode=evidence_mode, + conv_id=conv_id, + **kwargs, ) # prepare citation @@ -388,14 +401,29 @@ class FullQAPipeline(BaseComponent): for quote in fact_with_evidence.substring_quote: for doc in docs: start_idx = doc.text.find(quote) - if start_idx >= 0: + if start_idx == -1: + continue + + end_idx = start_idx + len(quote) + + current_idx = start_idx + if "|" not in doc.text[start_idx:end_idx]: spans[doc.doc_id].append( - { - "start": start_idx, - "end": start_idx + len(quote), - } + {"start": start_idx, "end": end_idx} ) - break + else: + while doc.text[current_idx:end_idx].find("|") != -1: + match_idx = doc.text[current_idx:end_idx].find("|") + spans[doc.doc_id].append( + { + "start": current_idx, + "end": current_idx + match_idx, + } + ) + current_idx += match_idx + 2 + if current_idx > end_idx: + break + break id2docs = {doc.doc_id: doc for doc in docs} lack_evidence = True @@ -414,12 +442,15 @@ class FullQAPipeline(BaseComponent): if idx < len(ss) - 1: text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] text += id2docs[id].text[ss[-1]["end"] :] + text_out = markdown.markdown( + text, extensions=["markdown.extensions.tables"] + ) self.report_output( { "evidence": ( "
" f"{id2docs[id].metadata['file_name']}" - f"{text}" + f"{text_out}" "

" ) } @@ -434,12 +465,15 @@ class FullQAPipeline(BaseComponent): {"evidence": "Retrieved segments without matching evidence:\n"} ) for id in list(not_detected): + text_out = markdown.markdown( + id2docs[id].text, extensions=["markdown.extensions.tables"] + ) self.report_output( { "evidence": ( "
" f"{id2docs[id].metadata['file_name']}" - f"{id2docs[id].text}" + f"{text_out}" "

" ) }