Fix UI bugs (#8)
* Auto create conversation when the user starts * Add conversation rename rule check * Fix empty name during save * Confirm deleting conversation * Show warning if users don't select file when uploading files in the File Index * Feedback when user uploads duplicated file * Limit the file types * Fix valid username * Allow login when username has leading and trailing whitespaces * Improve the user * Disable admin panel for non-admin user * Refresh user lists after creating/deleting users * Auto logging in * Clear admin information upon signing out * Fix unable to receive uploaded filenames that include special characters, like !@#$%^&*().pdf * Set upload validation for FileIndex * Improve user management UI/UX * Show extraction error when indexing file * Return selected user -1 when signing out * Fix default supported file types in file index * Validate changing password * Allow the selector to contain multiple gradio components * A more tolerable placeholder screen * Allow chat suggestion box * Increase concurrency limit * Make adobe loader optional * Use BaseReasoning --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
parent
43a18ba070
commit
ecf09b275f
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -452,6 +452,7 @@ $RECYCLE.BIN/
|
|||
.theflow/
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
|
||||
*.py[coid]
|
||||
|
||||
logs/
|
||||
.gitsecret/keys/random_seed
|
||||
|
|
|
@ -52,7 +52,12 @@ repos:
|
|||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
[types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
|
||||
[
|
||||
types-PyYAML==6.0.12.11,
|
||||
"types-requests",
|
||||
"sqlmodel",
|
||||
"types-Markdown",
|
||||
]
|
||||
args: ["--check-untyped-defs", "--ignore-missing-imports"]
|
||||
exclude: "^templates/"
|
||||
- repo: https://github.com/codespell-project/codespell
|
||||
|
|
|
@ -104,18 +104,16 @@ class CitationPipeline(BaseComponent):
|
|||
print("CitationPipeline: invoking LLM")
|
||||
llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
|
||||
print("CitationPipeline: finish invoking LLM")
|
||||
if not llm_output.messages:
|
||||
return None
|
||||
function_output = llm_output.messages[0].additional_kwargs["function_call"][
|
||||
"arguments"
|
||||
]
|
||||
output = QuestionAnswer.parse_raw(function_output)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
if not llm_output.messages:
|
||||
return None
|
||||
|
||||
function_output = llm_output.messages[0].additional_kwargs["function_call"][
|
||||
"arguments"
|
||||
]
|
||||
output = QuestionAnswer.parse_raw(function_output)
|
||||
|
||||
return output
|
||||
|
||||
async def ainvoke(self, context: str, question: str):
|
||||
|
|
|
@ -5,7 +5,7 @@ from .docx_loader import DocxReader
|
|||
from .excel_loader import PandasExcelReader
|
||||
from .html_loader import HtmlReader
|
||||
from .mathpix_loader import MathpixPDFReader
|
||||
from .ocr_loader import OCRReader
|
||||
from .ocr_loader import ImageReader, OCRReader
|
||||
from .unstructured_loader import UnstructuredReader
|
||||
|
||||
__all__ = [
|
||||
|
@ -13,6 +13,7 @@ __all__ = [
|
|||
"BaseReader",
|
||||
"PandasExcelReader",
|
||||
"MathpixPDFReader",
|
||||
"ImageReader",
|
||||
"OCRReader",
|
||||
"DirectoryReader",
|
||||
"UnstructuredReader",
|
||||
|
|
|
@ -10,14 +10,6 @@ from llama_index.readers.base import BaseReader
|
|||
|
||||
from kotaemon.base import Document
|
||||
|
||||
from .utils.adobe import (
|
||||
generate_figure_captions,
|
||||
load_json,
|
||||
parse_figure_paths,
|
||||
parse_table_paths,
|
||||
request_adobe_service,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_VLM_ENDPOINT = (
|
||||
|
@ -74,6 +66,13 @@ class AdobeReader(BaseReader):
|
|||
includes 3 types: text, table, and image
|
||||
|
||||
"""
|
||||
from .utils.adobe import (
|
||||
generate_figure_captions,
|
||||
load_json,
|
||||
parse_figure_paths,
|
||||
parse_table_paths,
|
||||
request_adobe_service,
|
||||
)
|
||||
|
||||
filename = file.name
|
||||
filepath = str(Path(file).resolve())
|
||||
|
|
|
@ -125,3 +125,70 @@ class OCRReader(BaseReader):
|
|||
)
|
||||
|
||||
return documents
|
||||
|
||||
|
||||
class ImageReader(BaseReader):
    """Read a file through the FullOCR endpoint

    Each entry of the OCR response becomes one Document whose content is the
    entry's `csv_string`.

    Example:
        ```python
        >> from knowledgehub.loaders import ImageReader
        >> reader = ImageReader()
        >> documents = reader.load_data("path/to/file")
        ```

    Args:
        endpoint: URL to FullOCR endpoint. If not provided, will look for
            environment variable `OCR_READER_ENDPOINT` or use the default
            `knowledgehub.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT`
    """

    def __init__(self, endpoint: Optional[str] = None):
        """Init the reader with the OCR endpoint (FullOCR pipeline)"""
        super().__init__()
        # explicit argument wins, then the environment variable, then the default
        self.ocr_endpoint = endpoint or os.getenv(
            "OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT
        )

    def load_data(
        self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
    ) -> List[Document]:
        """Load data using the OCR endpoint

        Args:
            file_path (Path): path to the file sent to the OCR endpoint
            extra_info (dict): metadata attached to every returned document
            **kwargs: when `response_content` is given, it is used as the OCR
                results and the endpoint is not called

        Returns:
            List[Document]: one document per OCR result entry
        """
        file_path = Path(file_path).resolve()

        with file_path.open("rb") as content:
            files = {"input": content}
            data = {"job_id": uuid4(), "table_only": False}

            if "response_content" in kwargs:
                # overriding response content if specified
                ocr_results = kwargs["response_content"]
            else:
                # call the API from FullOCR endpoint
                resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
                ocr_results = resp.json()["result"]

        extra_info = extra_info or {}

        # every document gets its own metadata dict so that mutating the
        # metadata of one document does not leak into the others
        return [
            Document(
                content=ocr_result["csv_string"],
                metadata=dict(extra_info),
            )
            for ocr_result in ocr_results
        ]
|
||||
|
|
|
@ -229,7 +229,9 @@ class BasePage:
|
|||
def _on_app_created(self):
|
||||
"""Called when the app is created"""
|
||||
|
||||
def as_gradio_component(self) -> Optional[gr.components.Component]:
|
||||
def as_gradio_component(
|
||||
self,
|
||||
) -> Optional[gr.components.Component | list[gr.components.Component]]:
|
||||
"""Return the gradio components responsible for events
|
||||
|
||||
Note: in ideal scenario, this method shouldn't be necessary.
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import abc
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ktem.app import BasePage
|
||||
|
@ -57,7 +57,7 @@ class BaseIndex(abc.ABC):
|
|||
self._app = app
|
||||
self.id = id
|
||||
self.name = name
|
||||
self._config = config # admin settings
|
||||
self.config = config # admin settings
|
||||
|
||||
def on_create(self):
|
||||
"""Create the index for the first time"""
|
||||
|
@ -121,7 +121,7 @@ class BaseIndex(abc.ABC):
|
|||
...
|
||||
|
||||
def get_retriever_pipelines(
|
||||
self, settings: dict, selected: Optional[list]
|
||||
self, settings: dict, selected: Any = None
|
||||
) -> list["BaseComponent"]:
|
||||
"""Return the retriever pipelines to retrieve the entity from the index"""
|
||||
return []
|
||||
|
|
|
@ -127,3 +127,11 @@ class BaseFileIndexIndexing(BaseComponent):
|
|||
the absolute file storage path to the file
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def warning(self, msg):
    """Report a warning to the user.

    This base implementation writes the message to standard output.

    Args:
        msg: the warning text to report
    """
    print(msg)
|
||||
|
|
|
@ -13,7 +13,6 @@ from theflow.utils.modules import import_dotted_string
|
|||
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
|
||||
|
||||
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
|
||||
from .ui import FileIndexPage, FileSelector
|
||||
|
||||
|
||||
class FileIndex(BaseIndex):
|
||||
|
@ -77,9 +76,15 @@ class FileIndex(BaseIndex):
|
|||
|
||||
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
|
||||
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
|
||||
self._selector_ui_cls: Type
|
||||
self._selector_ui: Any = None
|
||||
self._index_ui_cls: Type
|
||||
self._index_ui: Any = None
|
||||
|
||||
self._setup_indexing_cls()
|
||||
self._setup_retriever_cls()
|
||||
self._setup_file_index_ui_cls()
|
||||
self._setup_file_selector_ui_cls()
|
||||
|
||||
self._default_settings: dict[str, dict] = {}
|
||||
self._setting_mappings: dict[str, dict] = {}
|
||||
|
@ -91,14 +96,14 @@ class FileIndex(BaseIndex):
|
|||
|
||||
The indexing class will is retrieved from the following order. Stop at the
|
||||
first order found:
|
||||
- `FILE_INDEX_PIPELINE` in self._config
|
||||
- `FILE_INDEX_PIPELINE` in self.config
|
||||
- `FILE_INDEX_{id}_PIPELINE` in the flowsettings
|
||||
- `FILE_INDEX_PIPELINE` in the flowsettings
|
||||
- The default .pipelines.IndexDocumentPipeline
|
||||
"""
|
||||
if "FILE_INDEX_PIPELINE" in self._config:
|
||||
if "FILE_INDEX_PIPELINE" in self.config:
|
||||
self._indexing_pipeline_cls = import_dotted_string(
|
||||
self._config["FILE_INDEX_PIPELINE"], safe=False
|
||||
self.config["FILE_INDEX_PIPELINE"], safe=False
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -125,15 +130,15 @@ class FileIndex(BaseIndex):
|
|||
|
||||
The retriever classes will is retrieved from the following order. Stop at the
|
||||
first order found:
|
||||
- `FILE_INDEX_RETRIEVER_PIPELINES` in self._config
|
||||
- `FILE_INDEX_RETRIEVER_PIPELINES` in self.config
|
||||
- `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings
|
||||
- `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings
|
||||
- The default .pipelines.DocumentRetrievalPipeline
|
||||
"""
|
||||
if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config:
|
||||
if "FILE_INDEX_RETRIEVER_PIPELINES" in self.config:
|
||||
self._retriever_pipeline_cls = [
|
||||
import_dotted_string(each, safe=False)
|
||||
for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"]
|
||||
for each in self.config["FILE_INDEX_RETRIEVER_PIPELINES"]
|
||||
]
|
||||
return
|
||||
|
||||
|
@ -157,6 +162,76 @@ class FileIndex(BaseIndex):
|
|||
|
||||
self._retriever_pipeline_cls = [DocumentRetrievalPipeline]
|
||||
|
||||
def _setup_file_selector_ui_cls(self):
    """Resolve the file selector UI class for the file index

    The class is looked up in the following order. Stop at the first one
    found:
        - `FILE_INDEX_SELECTOR_UI` in self.config
        - `FILE_INDEX_{id}_SELECTOR_UI` in the flowsettings
        - `FILE_INDEX_SELECTOR_UI` in the flowsettings
        - The default .ui.FileSelector
    """
    if "FILE_INDEX_SELECTOR_UI" in self.config:
        self._selector_ui_cls = import_dotted_string(
            self.config["FILE_INDEX_SELECTOR_UI"], safe=False
        )
        return

    # index-specific flowsettings entry first, then the generic one
    for setting_name in (
        f"FILE_INDEX_{self.id}_SELECTOR_UI",
        "FILE_INDEX_SELECTOR_UI",
    ):
        if hasattr(flowsettings, setting_name):
            self._selector_ui_cls = import_dotted_string(
                getattr(flowsettings, setting_name), safe=False
            )
            return

    # nothing configured: fall back to the default selector, imported here
    # rather than at module level
    from .ui import FileSelector

    self._selector_ui_cls = FileSelector
|
||||
|
||||
def _setup_file_index_ui_cls(self):
    """Resolve the index page UI class for the file index

    The class is looked up in the following order. Stop at the first one
    found:
        - `FILE_INDEX_UI` in self.config
        - `FILE_INDEX_{id}_UI` in the flowsettings
        - `FILE_INDEX_UI` in the flowsettings
        - The default .ui.FileIndexPage
    """
    if "FILE_INDEX_UI" in self.config:
        self._index_ui_cls = import_dotted_string(
            self.config["FILE_INDEX_UI"], safe=False
        )
        return

    # index-specific flowsettings entry first, then the generic one
    for setting_name in (f"FILE_INDEX_{self.id}_UI", "FILE_INDEX_UI"):
        if hasattr(flowsettings, setting_name):
            self._index_ui_cls = import_dotted_string(
                getattr(flowsettings, setting_name), safe=False
            )
            return

    # nothing configured: fall back to the default page, imported here
    # rather than at module level
    from .ui import FileIndexPage

    self._index_ui_cls = FileIndexPage
|
||||
|
||||
def on_create(self):
|
||||
"""Create the index for the first time
|
||||
|
||||
|
@ -165,6 +240,13 @@ class FileIndex(BaseIndex):
|
|||
2. Create the vectorstore
|
||||
3. Create the docstore
|
||||
"""
|
||||
file_types_str = self.config.get(
|
||||
"supported_file_types",
|
||||
self.get_admin_settings()["supported_file_types"]["value"],
|
||||
)
|
||||
file_types = [each.strip() for each in file_types_str.split(",")]
|
||||
self.config["supported_file_types"] = file_types
|
||||
|
||||
self._resources["Source"].metadata.create_all(engine) # type: ignore
|
||||
self._resources["Index"].metadata.create_all(engine) # type: ignore
|
||||
self._fs_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -180,10 +262,14 @@ class FileIndex(BaseIndex):
|
|||
shutil.rmtree(self._fs_path)
|
||||
|
||||
def get_selector_component_ui(self):
|
||||
return FileSelector(self._app, self)
|
||||
if self._selector_ui is None:
|
||||
self._selector_ui = self._selector_ui_cls(self._app, self)
|
||||
return self._selector_ui
|
||||
|
||||
def get_index_page_ui(self):
|
||||
return FileIndexPage(self._app, self)
|
||||
if self._index_ui is None:
|
||||
self._index_ui = self._index_ui_cls(self._app, self)
|
||||
return self._index_ui
|
||||
|
||||
def get_user_settings(self):
|
||||
if self._default_settings:
|
||||
|
@ -210,7 +296,31 @@ class FileIndex(BaseIndex):
|
|||
"value": embedding_default,
|
||||
"component": "dropdown",
|
||||
"choices": embedding_choices,
|
||||
}
|
||||
},
|
||||
"supported_file_types": {
|
||||
"name": "Supported file types",
|
||||
"value": (
|
||||
"image, .pdf, .txt, .csv, .xlsx, .doc, .docx, .pptx, .html, .zip"
|
||||
),
|
||||
"component": "text",
|
||||
},
|
||||
"max_file_size": {
|
||||
"name": "Max file size (MB) - set 0 to disable",
|
||||
"value": 1000,
|
||||
"component": "number",
|
||||
},
|
||||
"max_number_of_files": {
|
||||
"name": "Max number of files that can be indexed - set 0 to disable",
|
||||
"value": 0,
|
||||
"component": "number",
|
||||
},
|
||||
"max_number_of_text_length": {
|
||||
"name": (
|
||||
"Max amount of characters that can be indexed - set 0 to disable"
|
||||
),
|
||||
"value": 0,
|
||||
"component": "number",
|
||||
},
|
||||
}
|
||||
|
||||
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
|
||||
|
@ -224,14 +334,15 @@ class FileIndex(BaseIndex):
|
|||
else:
|
||||
stripped_settings[key] = value
|
||||
|
||||
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config)
|
||||
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
|
||||
obj.set_resources(resources=self._resources)
|
||||
|
||||
return obj
|
||||
|
||||
def get_retriever_pipelines(
|
||||
self, settings: dict, selected: Optional[list] = None
|
||||
self, settings: dict, selected: Any = None
|
||||
) -> list["BaseFileIndexRetriever"]:
|
||||
# retrieval settings
|
||||
prefix = f"index.options.{self.id}."
|
||||
stripped_settings = {}
|
||||
for key, value in settings.items():
|
||||
|
@ -240,9 +351,12 @@ class FileIndex(BaseIndex):
|
|||
else:
|
||||
stripped_settings[key] = value
|
||||
|
||||
# transform selected id
|
||||
selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)
|
||||
|
||||
retrievers = []
|
||||
for cls in self._retriever_pipeline_cls:
|
||||
obj = cls.get_pipeline(stripped_settings, self._config, selected)
|
||||
obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
|
||||
if obj is None:
|
||||
continue
|
||||
obj.set_resources(self._resources)
|
||||
|
|
|
@ -9,6 +9,7 @@ from hashlib import sha256
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import gradio as gr
|
||||
from ktem.components import embeddings, filestorage_path
|
||||
from ktem.db.models import engine
|
||||
from llama_index.vector_stores import (
|
||||
|
@ -18,7 +19,7 @@ from llama_index.vector_stores import (
|
|||
MetadataFilters,
|
||||
)
|
||||
from llama_index.vector_stores.types import VectorStoreQueryMode
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
from theflow.settings import settings
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
@ -279,6 +280,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
|
|||
to_index: list[str] = []
|
||||
file_to_hash: dict[str, str] = {}
|
||||
errors = []
|
||||
to_update = []
|
||||
|
||||
for file_path in file_paths:
|
||||
abs_path = str(Path(file_path).resolve())
|
||||
|
@ -291,16 +293,26 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
|
|||
statement = select(Source).where(Source.name == Path(abs_path).name)
|
||||
item = session.execute(statement).first()
|
||||
|
||||
if item and not reindex:
|
||||
errors.append(Path(abs_path).name)
|
||||
continue
|
||||
if item:
|
||||
if not reindex:
|
||||
errors.append(Path(abs_path).name)
|
||||
continue
|
||||
else:
|
||||
to_update.append(Path(abs_path).name)
|
||||
|
||||
to_index.append(abs_path)
|
||||
|
||||
if errors:
|
||||
error_files = ", ".join(errors)
|
||||
if len(error_files) > 100:
|
||||
error_files = error_files[:80] + "..."
|
||||
print(
|
||||
"Files already exist. Please rename/remove them or enable reindex.\n"
|
||||
f"{errors}"
|
||||
"Skip these files already exist. Please rename/remove them or "
|
||||
f"enable reindex:\n{errors}"
|
||||
)
|
||||
self.warning(
|
||||
"Skip these files already exist. Please rename/remove them or "
|
||||
f"enable reindex:\n{error_files}"
|
||||
)
|
||||
|
||||
if not to_index:
|
||||
|
@ -310,9 +322,19 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
|
|||
for path in to_index:
|
||||
shutil.copy(path, filestorage_path / file_to_hash[path])
|
||||
|
||||
# prepare record info
|
||||
# extract the file & prepare record info
|
||||
file_to_source: dict = {}
|
||||
extraction_errors = []
|
||||
nodes = []
|
||||
for file_path, file_hash in file_to_hash.items():
|
||||
if str(Path(file_path).resolve()) not in to_index:
|
||||
continue
|
||||
|
||||
extraction_result = self.file_ingestor(file_path)
|
||||
if not extraction_result:
|
||||
extraction_errors.append(Path(file_path).name)
|
||||
continue
|
||||
nodes.extend(extraction_result)
|
||||
source = Source(
|
||||
name=Path(file_path).name,
|
||||
path=file_hash,
|
||||
|
@ -320,9 +342,23 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
|
|||
)
|
||||
file_to_source[file_path] = source
|
||||
|
||||
# extract the files
|
||||
nodes = self.file_ingestor(to_index)
|
||||
print("Extracted", len(to_index), "files into", len(nodes), "nodes")
|
||||
if extraction_errors:
|
||||
msg = "Failed to extract these files: {}".format(
|
||||
", ".join(extraction_errors)
|
||||
)
|
||||
print(msg)
|
||||
self.warning(msg)
|
||||
|
||||
if not nodes:
|
||||
return [], []
|
||||
|
||||
print(
|
||||
"Extracted",
|
||||
len(to_index) - len(extraction_errors),
|
||||
"files into",
|
||||
len(nodes),
|
||||
"nodes",
|
||||
)
|
||||
|
||||
# index the files
|
||||
print("Indexing the files into vector store")
|
||||
|
@ -332,7 +368,11 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
|
|||
# persist to the index
|
||||
print("Persisting the vector and the document into index")
|
||||
file_ids = []
|
||||
to_update = list(set(to_update))
|
||||
with Session(engine) as session:
|
||||
if to_update:
|
||||
session.execute(delete(Source).where(Source.name.in_(to_update)))
|
||||
|
||||
for source in file_to_source.values():
|
||||
session.add(source)
|
||||
session.commit()
|
||||
|
@ -404,3 +444,6 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
|
|||
super().set_resources(resources)
|
||||
self.indexing_vector_pipeline.vector_store = self._VS
|
||||
self.indexing_vector_pipeline.doc_store = self._DS
|
||||
|
||||
def warning(self, msg):
    """Surface a warning message through gradio's `gr.Warning`

    Args:
        msg: the warning text to show
    """
    gr.Warning(msg)
|
||||
|
|
|
@ -1,29 +1,48 @@
|
|||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import pandas as pd
|
||||
from gradio.data_classes import FileData
|
||||
from gradio.utils import NamedString
|
||||
from ktem.app import BasePage
|
||||
from ktem.db.engine import engine
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class File(gr.File):
    """Subclass from gr.File to maintain the original filename

    The issue happens when user uploads file with name like: !@#$%%^&*().pdf
    """

    def _process_single_file(self, f: FileData) -> NamedString | bytes:
        """Convert one uploaded file according to `self.type`

        Args:
            f: the uploaded file data; `f.path` is where the file is stored and
                `f.orig_name` is the name the file was uploaded with

        Returns:
            the file path (as NamedString) when `self.type` is "filepath", or
            the file content as bytes when `self.type` is "binary"

        Raises:
            ValueError: if `self.type` is neither "filepath" nor "binary"
        """
        file_name = f.path
        if self.type == "filepath":
            if f.orig_name and Path(file_name).name != f.orig_name:
                # rename the stored file so that it carries the original name
                file_name = str(Path(file_name).parent / f.orig_name)
                os.rename(f.path, file_name)

            # The original code created this placeholder file in the cache
            # directory and left the handle open; keep creating it but release
            # the handle right away.
            # NOTE(review): the placeholder is not used afterwards here --
            # confirm whether it is needed at all.
            with tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE):
                pass
            return NamedString(file_name)

        if self.type == "binary":
            with open(file_name, "rb") as file_data:
                return file_data.read()

        # report the configured type of this component, not the builtin `type`
        raise ValueError(
            "Unknown type: "
            + str(self.type)
            + ". Please choose from: 'filepath', 'binary'."
        )
|
||||
|
||||
|
||||
class DirectoryUpload(BasePage):
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self._supported_file_types = [
|
||||
"image",
|
||||
".pdf",
|
||||
".txt",
|
||||
".csv",
|
||||
".xlsx",
|
||||
".doc",
|
||||
".docx",
|
||||
".pptx",
|
||||
".html",
|
||||
".zip",
|
||||
]
|
||||
def __init__(self, app, index):
|
||||
super().__init__(app)
|
||||
self._index = index
|
||||
self._supported_file_types = self._index.config.get("supported_file_types", [])
|
||||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
|
@ -50,18 +69,7 @@ class FileIndexPage(BasePage):
|
|||
def __init__(self, app, index):
|
||||
super().__init__(app)
|
||||
self._index = index
|
||||
self._supported_file_types = [
|
||||
"image",
|
||||
".pdf",
|
||||
".txt",
|
||||
".csv",
|
||||
".xlsx",
|
||||
".doc",
|
||||
".docx",
|
||||
".pptx",
|
||||
".html",
|
||||
".zip",
|
||||
]
|
||||
self._supported_file_types = self._index.config.get("supported_file_types", [])
|
||||
self.selected_panel_false = "Selected file: (please select above)"
|
||||
self.selected_panel_true = "Selected file: {name}"
|
||||
# TODO: on_building_ui is not correctly named if it's always called in
|
||||
|
@ -69,13 +77,32 @@ class FileIndexPage(BasePage):
|
|||
self.public_events = [f"onFileIndex{index.id}Changed"]
|
||||
self.on_building_ui()
|
||||
|
||||
def upload_instruction(self) -> str:
    """Compose the upload hint from the index's upload constraints

    Returns:
        one "- ..." line per active constraint (supported file types, maximum
        file size, maximum number of files) joined by newlines, or an empty
        string when no constraint is active
    """
    index_config = self._index.config
    lines = []

    if self._supported_file_types:
        types_str = ", ".join(self._supported_file_types)
        lines.append(f"- Supported file types: {types_str}")

    size_limit = index_config.get("max_file_size", 0)
    if size_limit:
        lines.append(f"- Maximum file size: {size_limit} MB")

    count_limit = index_config.get("max_number_of_files", 0)
    if count_limit:
        lines.append(f"- The index can have maximum {count_limit} files")

    # joining an empty list yields "", which is the no-constraint result
    return "\n".join(lines)
|
||||
|
||||
def on_building_ui(self):
|
||||
"""Build the UI of the app"""
|
||||
with gr.Accordion(label="File upload", open=False):
|
||||
gr.Markdown(
|
||||
f"Supported file types: {', '.join(self._supported_file_types)}",
|
||||
)
|
||||
self.files = gr.File(
|
||||
with gr.Accordion(label="File upload", open=True) as self.upload:
|
||||
msg = self.upload_instruction()
|
||||
if msg:
|
||||
gr.Markdown(msg)
|
||||
|
||||
self.files = File(
|
||||
file_types=self._supported_file_types,
|
||||
file_count="multiple",
|
||||
container=False,
|
||||
|
@ -98,18 +125,20 @@ class FileIndexPage(BasePage):
|
|||
interactive=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row() as self.selection_info:
|
||||
self.selected_file_id = gr.State(value=None)
|
||||
self.selected_panel = gr.Markdown(self.selected_panel_false)
|
||||
self.deselect_button = gr.Button("Deselect", visible=False)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Row() as self.tools:
|
||||
with gr.Column():
|
||||
self.view_button = gr.Button("View Text (WIP)")
|
||||
with gr.Column():
|
||||
self.delete_button = gr.Button("Delete")
|
||||
with gr.Row():
|
||||
self.delete_yes = gr.Button("Confirm Delete", visible=False)
|
||||
self.delete_yes = gr.Button(
|
||||
"Confirm Delete", variant="primary", visible=False
|
||||
)
|
||||
self.delete_no = gr.Button("Cancel", visible=False)
|
||||
|
||||
def on_subscribe_public_events(self):
|
||||
|
@ -242,10 +271,12 @@ class FileIndexPage(BasePage):
|
|||
self._app.settings_state,
|
||||
],
|
||||
outputs=[self.file_output],
|
||||
concurrency_limit=20,
|
||||
).then(
|
||||
fn=self.list_file,
|
||||
inputs=None,
|
||||
outputs=[self.file_list_state, self.file_list],
|
||||
concurrency_limit=20,
|
||||
)
|
||||
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
|
||||
onUploaded = onUploaded.then(**event)
|
||||
|
@ -274,6 +305,15 @@ class FileIndexPage(BasePage):
|
|||
selected_files: the list of files already selected
|
||||
settings: the settings of the app
|
||||
"""
|
||||
if not files:
|
||||
gr.Info("No uploaded file")
|
||||
return gr.update()
|
||||
|
||||
errors = self.validate(files)
|
||||
if errors:
|
||||
gr.Warning(", ".join(errors))
|
||||
return gr.update()
|
||||
|
||||
gr.Info(f"Start indexing {len(files)} files...")
|
||||
|
||||
# get the pipeline
|
||||
|
@ -409,6 +449,35 @@ class FileIndexPage(BasePage):
|
|||
name=list_files["name"][ev.index[0]]
|
||||
)
|
||||
|
||||
def validate(self, files: list[str]):
    """Check the files against the limits configured on the index

    Args:
        files: paths of the files to check

    Returns:
        list of error messages; empty when no configured limit is violated
    """
    index_config = self._index.config
    candidates = [Path(each) for each in files]
    problems = []

    size_limit = index_config.get("max_file_size", 0)
    if size_limit:
        # the limit is compared as `size_limit * 1e6` against the size in bytes
        oversized = [
            item.name for item in candidates if item.stat().st_size > size_limit * 1e6
        ]
        if oversized:
            names = ", ".join(oversized)
            if len(names) > 60:
                # keep the message short
                names = names[:55] + "..."
            problems.append(f"Maximum file size ({size_limit} MB) exceeded: {names}")

    count_limit = index_config.get("max_number_of_files", 0)
    if count_limit:
        # count the entries already stored in the index's "Source" table
        with Session(engine) as session:
            existing = session.query(self._index._db_tables["Source"].id).count()
        if len(candidates) + existing > count_limit:
            problems.append(f"Maximum number of files ({count_limit}) will be exceeded")

    return problems
|
||||
|
||||
|
||||
class FileSelector(BasePage):
|
||||
"""File selector UI in the Chat page"""
|
||||
|
@ -430,6 +499,9 @@ class FileSelector(BasePage):
|
|||
def as_gradio_component(self):
|
||||
return self.selector
|
||||
|
||||
def get_selected_ids(self, selected):
    """Return the selected ids for the given selector value

    Args:
        selected: the value of the selector component

    Returns:
        `selected`, unchanged
    """
    return selected
|
||||
|
||||
def load_files(self, selected_files):
|
||||
options = []
|
||||
available_ids = []
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Type
|
||||
from typing import Optional, Type
|
||||
|
||||
from ktem.db.models import engine
|
||||
from sqlmodel import Session, select
|
||||
|
@ -49,15 +49,19 @@ class IndexManager:
|
|||
Returns:
|
||||
BaseIndex: the index object
|
||||
"""
|
||||
index_cls = import_dotted_string(index_type, safe=False)
|
||||
index = index_cls(app=self._app, id=id, name=name, config=config)
|
||||
index.on_create()
|
||||
|
||||
with Session(engine) as session:
|
||||
index_entry = Index(id=id, name=name, config=config, index_type=index_type)
|
||||
index_entry = Index(
|
||||
id=index.id, name=index.name, config=index.config, index_type=index_type
|
||||
)
|
||||
session.add(index_entry)
|
||||
session.commit()
|
||||
session.refresh(index_entry)
|
||||
|
||||
index_cls = import_dotted_string(index_type, safe=False)
|
||||
index = index_cls(app=self._app, id=id, name=name, config=config)
|
||||
index.on_create()
|
||||
index.id = index_entry.id
|
||||
|
||||
return index
|
||||
|
||||
|
@ -77,7 +81,7 @@ class IndexManager:
|
|||
self._indices.append(index)
|
||||
return index
|
||||
|
||||
def exists(self, id: int) -> bool:
|
||||
def exists(self, id: Optional[int] = None, name: Optional[str] = None) -> bool:
|
||||
"""Check if the index exists
|
||||
|
||||
Args:
|
||||
|
@ -86,9 +90,19 @@ class IndexManager:
|
|||
Returns:
|
||||
bool: True if the index exists, False otherwise
|
||||
"""
|
||||
with Session(engine) as session:
|
||||
index = session.get(Index, id)
|
||||
return index is not None
|
||||
if id:
|
||||
with Session(engine) as session:
|
||||
index = session.get(Index, id)
|
||||
return index is not None
|
||||
|
||||
if name:
|
||||
with Session(engine) as session:
|
||||
index = session.exec(
|
||||
select(Index).where(Index.name == name)
|
||||
).one_or_none()
|
||||
return index is not None
|
||||
|
||||
return False
|
||||
|
||||
def on_application_startup(self):
|
||||
"""This method is called by the base application when the application starts
|
||||
|
|
|
@ -27,7 +27,7 @@ class App(BaseApp):
|
|||
if self.f_user_management:
|
||||
from ktem.pages.login import LoginPage
|
||||
|
||||
with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]:
|
||||
with gr.Tab("Welcome", elem_id="login-tab") as self._tabs["login-tab"]:
|
||||
self.login_page = LoginPage(self)
|
||||
|
||||
with gr.Tab(
|
||||
|
@ -62,6 +62,9 @@ class App(BaseApp):
|
|||
|
||||
def on_subscribe_public_events(self):
|
||||
if self.f_user_management:
|
||||
from ktem.db.engine import engine
|
||||
from ktem.db.models import User
|
||||
from sqlmodel import Session, select
|
||||
|
||||
def signed_in_out(user_id):
|
||||
if not user_id:
|
||||
|
@ -73,14 +76,31 @@ class App(BaseApp):
|
|||
)
|
||||
for k in self._tabs.keys()
|
||||
)
|
||||
return list(
|
||||
(
|
||||
gr.update(visible=True)
|
||||
if k != "login-tab"
|
||||
else gr.update(visible=False)
|
||||
)
|
||||
for k in self._tabs.keys()
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
user = session.exec(select(User).where(User.id == user_id)).first()
|
||||
if user is None:
|
||||
return list(
|
||||
(
|
||||
gr.update(visible=True)
|
||||
if k == "login-tab"
|
||||
else gr.update(visible=False)
|
||||
)
|
||||
for k in self._tabs.keys()
|
||||
)
|
||||
|
||||
is_admin = user.admin
|
||||
|
||||
tabs_update = []
|
||||
for k in self._tabs.keys():
|
||||
if k == "login-tab":
|
||||
tabs_update.append(gr.update(visible=False))
|
||||
elif k == "admin-tab":
|
||||
tabs_update.append(gr.update(visible=is_admin))
|
||||
else:
|
||||
tabs_update.append(gr.update(visible=True))
|
||||
|
||||
return tabs_update
|
||||
|
||||
self.subscribe_event(
|
||||
name="onSignIn",
|
||||
|
|
|
@ -40,7 +40,7 @@ def validate_username(usn):
|
|||
if len(usn) > 32:
|
||||
errors.append("Username must be at most 32 characters long")
|
||||
|
||||
if not usn.strip("_").isalnum():
|
||||
if not usn.replace("_", "").isalnum():
|
||||
errors.append(
|
||||
"Username must contain only alphanumeric characters and underscores"
|
||||
)
|
||||
|
@ -97,8 +97,6 @@ def validate_password(pwd, pwd_cnf):
|
|||
class UserManagement(BasePage):
|
||||
def __init__(self, app):
|
||||
self._app = app
|
||||
self.selected_panel_false = "Selected user: (please select above)"
|
||||
self.selected_panel_true = "Selected user: {name}"
|
||||
|
||||
self.on_building_ui()
|
||||
if hasattr(flowsettings, "KH_FEATURE_USER_MANAGEMENT_ADMIN") and hasattr(
|
||||
|
@ -126,7 +124,38 @@ class UserManagement(BasePage):
|
|||
gr.Info(f'User "{usn}" created successfully')
|
||||
|
||||
def on_building_ui(self):
|
||||
with gr.Accordion(label="Create user", open=False):
|
||||
with gr.Tab(label="User list"):
|
||||
self.state_user_list = gr.State(value=None)
|
||||
self.user_list = gr.DataFrame(
|
||||
headers=["id", "name", "admin"],
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
with gr.Group(visible=False) as self._selected_panel:
|
||||
self.selected_user_id = gr.Number(value=-1, visible=False)
|
||||
self.usn_edit = gr.Textbox(label="Username")
|
||||
with gr.Row():
|
||||
self.pwd_edit = gr.Textbox(label="Change password", type="password")
|
||||
self.pwd_cnf_edit = gr.Textbox(
|
||||
label="Confirm change password",
|
||||
type="password",
|
||||
)
|
||||
self.admin_edit = gr.Checkbox(label="Admin")
|
||||
|
||||
with gr.Row() as self._selected_panel_btn:
|
||||
with gr.Column():
|
||||
self.btn_edit_save = gr.Button("Save")
|
||||
with gr.Column():
|
||||
self.btn_delete = gr.Button("Delete")
|
||||
with gr.Row():
|
||||
self.btn_delete_yes = gr.Button(
|
||||
"Confirm delete", variant="primary", visible=False
|
||||
)
|
||||
self.btn_delete_no = gr.Button("Cancel", visible=False)
|
||||
with gr.Column():
|
||||
self.btn_close = gr.Button("Close")
|
||||
|
||||
with gr.Tab(label="Create user"):
|
||||
self.usn_new = gr.Textbox(label="Username", interactive=True)
|
||||
self.pwd_new = gr.Textbox(
|
||||
label="Password", type="password", interactive=True
|
||||
|
@ -139,52 +168,28 @@ class UserManagement(BasePage):
|
|||
gr.Markdown(PASSWORD_RULE)
|
||||
self.btn_new = gr.Button("Create user")
|
||||
|
||||
gr.Markdown("## User list")
|
||||
self.btn_list_user = gr.Button("Refresh user list")
|
||||
self.state_user_list = gr.State(value=None)
|
||||
self.user_list = gr.DataFrame(
|
||||
headers=["id", "name", "admin"],
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.selected_user_id = gr.State(value=None)
|
||||
self.selected_panel = gr.Markdown(self.selected_panel_false)
|
||||
self.deselect_button = gr.Button("Deselect", visible=False)
|
||||
|
||||
with gr.Group():
|
||||
self.btn_delete = gr.Button("Delete user")
|
||||
with gr.Row():
|
||||
self.btn_delete_yes = gr.Button("Confirm", visible=False)
|
||||
self.btn_delete_no = gr.Button("Cancel", visible=False)
|
||||
|
||||
gr.Markdown("## User details")
|
||||
self.usn_edit = gr.Textbox(label="Username")
|
||||
self.pwd_edit = gr.Textbox(label="Password", type="password")
|
||||
self.pwd_cnf_edit = gr.Textbox(label="Confirm password", type="password")
|
||||
self.admin_edit = gr.Checkbox(label="Admin")
|
||||
self.btn_edit_save = gr.Button("Save")
|
||||
|
||||
def on_register_events(self):
|
||||
self.btn_new.click(
|
||||
self.create_user,
|
||||
inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
|
||||
outputs=None,
|
||||
)
|
||||
self.btn_list_user.click(
|
||||
self.list_users, inputs=None, outputs=[self.state_user_list, self.user_list]
|
||||
outputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
|
||||
).then(
|
||||
self.list_users,
|
||||
inputs=self._app.user_id,
|
||||
outputs=[self.state_user_list, self.user_list],
|
||||
)
|
||||
self.user_list.select(
|
||||
self.select_user,
|
||||
inputs=self.user_list,
|
||||
outputs=[self.selected_user_id, self.selected_panel],
|
||||
outputs=[self.selected_user_id],
|
||||
show_progress="hidden",
|
||||
)
|
||||
self.selected_panel.change(
|
||||
self.selected_user_id.change(
|
||||
self.on_selected_user_change,
|
||||
inputs=[self.selected_user_id],
|
||||
outputs=[
|
||||
self.deselect_button,
|
||||
self._selected_panel,
|
||||
self._selected_panel_btn,
|
||||
# delete section
|
||||
self.btn_delete,
|
||||
self.btn_delete_yes,
|
||||
|
@ -197,12 +202,6 @@ class UserManagement(BasePage):
|
|||
],
|
||||
show_progress="hidden",
|
||||
)
|
||||
self.deselect_button.click(
|
||||
lambda: (None, self.selected_panel_false),
|
||||
inputs=None,
|
||||
outputs=[self.selected_user_id, self.selected_panel],
|
||||
show_progress="hidden",
|
||||
)
|
||||
self.btn_delete.click(
|
||||
self.on_btn_delete_click,
|
||||
inputs=[self.selected_user_id],
|
||||
|
@ -211,9 +210,13 @@ class UserManagement(BasePage):
|
|||
)
|
||||
self.btn_delete_yes.click(
|
||||
self.delete_user,
|
||||
inputs=[self.selected_user_id],
|
||||
outputs=[self.selected_user_id, self.selected_panel],
|
||||
inputs=[self._app.user_id, self.selected_user_id],
|
||||
outputs=[self.selected_user_id],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.list_users,
|
||||
inputs=self._app.user_id,
|
||||
outputs=[self.state_user_list, self.user_list],
|
||||
)
|
||||
self.btn_delete_no.click(
|
||||
lambda: (
|
||||
|
@ -234,21 +237,53 @@ class UserManagement(BasePage):
|
|||
self.pwd_cnf_edit,
|
||||
self.admin_edit,
|
||||
],
|
||||
outputs=None,
|
||||
outputs=[self.pwd_edit, self.pwd_cnf_edit],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.list_users,
|
||||
inputs=self._app.user_id,
|
||||
outputs=[self.state_user_list, self.user_list],
|
||||
)
|
||||
self.btn_close.click(
|
||||
lambda: -1,
|
||||
outputs=[self.selected_user_id],
|
||||
)
|
||||
|
||||
def on_subscribe_public_events(self):
    """Hook the user-management widgets into the app-wide sign-in/out events.

    * ``onSignIn``: re-runs ``list_users`` with the current user id and
      writes the result to the cached user list and the displayed table.
    * ``onSignOut``: blanks the three "create user" text fields, clears the
      cached list and the table, and resets the selected user id to ``-1``.
    """
    refresh_on_sign_in = {
        "fn": self.list_users,
        "inputs": [self._app.user_id],
        "outputs": [self.state_user_list, self.user_list],
    }
    self._app.subscribe_event(name="onSignIn", definition=refresh_on_sign_in)

    # The reset lambda returns one value per component below, in this order.
    cleared_components = [
        self.usn_new,
        self.pwd_new,
        self.pwd_cnf_new,
        self.state_user_list,
        self.user_list,
        self.selected_user_id,
    ]
    reset_on_sign_out = {
        "fn": lambda: ("", "", "", None, None, -1),
        "outputs": cleared_components,
    }
    self._app.subscribe_event(name="onSignOut", definition=reset_on_sign_out)
|
||||
|
||||
def create_user(self, usn, pwd, pwd_cnf):
|
||||
errors = validate_username(usn)
|
||||
if errors:
|
||||
gr.Warning(errors)
|
||||
return
|
||||
return usn, pwd, pwd_cnf
|
||||
|
||||
errors = validate_password(pwd, pwd_cnf)
|
||||
print(errors)
|
||||
if errors:
|
||||
gr.Warning(errors)
|
||||
return
|
||||
return usn, pwd, pwd_cnf
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(User.username_lower == usn.lower())
|
||||
|
@ -265,8 +300,22 @@ class UserManagement(BasePage):
|
|||
session.commit()
|
||||
gr.Info(f'User "{usn}" created successfully')
|
||||
|
||||
def list_users(self):
|
||||
return "", "", ""
|
||||
|
||||
def list_users(self, user_id):
|
||||
if user_id is None:
|
||||
return [], pd.DataFrame.from_records(
|
||||
[{"id": "-", "username": "-", "admin": "-"}]
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(User.id == user_id)
|
||||
user = session.exec(statement).one()
|
||||
if not user.admin:
|
||||
return [], pd.DataFrame.from_records(
|
||||
[{"id": "-", "username": "-", "admin": "-"}]
|
||||
)
|
||||
|
||||
statement = select(User)
|
||||
results = [
|
||||
{"id": user.id, "username": user.username, "admin": user.admin}
|
||||
|
@ -284,18 +333,17 @@ class UserManagement(BasePage):
|
|||
def select_user(self, user_list, ev: gr.SelectData):
|
||||
if ev.value == "-" and ev.index[0] == 0:
|
||||
gr.Info("No user is loaded. Please refresh the user list")
|
||||
return None, self.selected_panel_false
|
||||
return -1
|
||||
|
||||
if not ev.selected:
|
||||
return None, self.selected_panel_false
|
||||
return -1
|
||||
|
||||
return user_list["id"][ev.index[0]], self.selected_panel_true.format(
|
||||
name=user_list["username"][ev.index[0]]
|
||||
)
|
||||
return user_list["id"][ev.index[0]]
|
||||
|
||||
def on_selected_user_change(self, selected_user_id):
|
||||
if selected_user_id is None:
|
||||
deselect_button = gr.update(visible=False)
|
||||
if selected_user_id == -1:
|
||||
_selected_panel = gr.update(visible=False)
|
||||
_selected_panel_btn = gr.update(visible=False)
|
||||
btn_delete = gr.update(visible=True)
|
||||
btn_delete_yes = gr.update(visible=False)
|
||||
btn_delete_no = gr.update(visible=False)
|
||||
|
@ -304,7 +352,8 @@ class UserManagement(BasePage):
|
|||
pwd_cnf_edit = gr.update(value="")
|
||||
admin_edit = gr.update(value=False)
|
||||
else:
|
||||
deselect_button = gr.update(visible=True)
|
||||
_selected_panel = gr.update(visible=True)
|
||||
_selected_panel_btn = gr.update(visible=True)
|
||||
btn_delete = gr.update(visible=True)
|
||||
btn_delete_yes = gr.update(visible=False)
|
||||
btn_delete_no = gr.update(visible=False)
|
||||
|
@ -319,7 +368,8 @@ class UserManagement(BasePage):
|
|||
admin_edit = gr.update(value=user.admin)
|
||||
|
||||
return (
|
||||
deselect_button,
|
||||
_selected_panel,
|
||||
_selected_panel_btn,
|
||||
btn_delete,
|
||||
btn_delete_yes,
|
||||
btn_delete_no,
|
||||
|
@ -344,17 +394,16 @@ class UserManagement(BasePage):
|
|||
return btn_delete, btn_delete_yes, btn_delete_no
|
||||
|
||||
def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin):
|
||||
if usn:
|
||||
errors = validate_username(usn)
|
||||
if errors:
|
||||
gr.Warning(errors)
|
||||
return
|
||||
errors = validate_username(usn)
|
||||
if errors:
|
||||
gr.Warning(errors)
|
||||
return pwd, pwd_cnf
|
||||
|
||||
if pwd:
|
||||
errors = validate_password(pwd, pwd_cnf)
|
||||
if errors:
|
||||
gr.Warning(errors)
|
||||
return
|
||||
return pwd, pwd_cnf
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(User.id == int(selected_user_id))
|
||||
|
@ -367,11 +416,17 @@ class UserManagement(BasePage):
|
|||
session.commit()
|
||||
gr.Info(f'User "{usn}" updated successfully')
|
||||
|
||||
def delete_user(self, selected_user_id):
|
||||
return "", ""
|
||||
|
||||
def delete_user(self, current_user, selected_user_id):
|
||||
if current_user == selected_user_id:
|
||||
gr.Warning("You cannot delete yourself")
|
||||
return selected_user_id
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(User).where(User.id == int(selected_user_id))
|
||||
user = session.exec(statement).one()
|
||||
session.delete(user)
|
||||
session.commit()
|
||||
gr.Info(f'User "{user.username}" deleted successfully')
|
||||
return None, self.selected_panel_false
|
||||
return -1
|
||||
|
|
|
@ -7,8 +7,10 @@ from ktem.app import BasePage
|
|||
from ktem.components import reasonings
|
||||
from ktem.db.models import Conversation, engine
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings as flowsettings
|
||||
|
||||
from .chat_panel import ChatPanel
|
||||
from .chat_suggestion import ChatSuggestion
|
||||
from .common import STATE
|
||||
from .control import ConversationControl
|
||||
from .report import ReportIssue
|
||||
|
@ -26,24 +28,39 @@ class ChatPage(BasePage):
|
|||
with gr.Column(scale=1):
|
||||
self.chat_control = ConversationControl(self._app)
|
||||
|
||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||
self.chat_suggestion = ChatSuggestion(self._app)
|
||||
|
||||
for index in self._app.index_manager.indices:
|
||||
index.selector = -1
|
||||
index.selector = None
|
||||
index_ui = index.get_selector_component_ui()
|
||||
if not index_ui:
|
||||
# the index doesn't have a selector UI component
|
||||
continue
|
||||
|
||||
index_ui.unrender()
|
||||
index_ui.unrender() # need to rerender later within Accordion
|
||||
with gr.Accordion(label=f"{index.name} Index", open=False):
|
||||
index_ui.render()
|
||||
gr_index = index_ui.as_gradio_component()
|
||||
if gr_index:
|
||||
index.selector = len(self._indices_input)
|
||||
self._indices_input.append(gr_index)
|
||||
if isinstance(gr_index, list):
|
||||
index.selector = tuple(
|
||||
range(
|
||||
len(self._indices_input),
|
||||
len(self._indices_input) + len(gr_index),
|
||||
)
|
||||
)
|
||||
self._indices_input.extend(gr_index)
|
||||
else:
|
||||
index.selector = len(self._indices_input)
|
||||
self._indices_input.append(gr_index)
|
||||
setattr(self, f"_index_{index.id}", index_ui)
|
||||
|
||||
self.report_issue = ReportIssue(self._app)
|
||||
|
||||
with gr.Column(scale=6):
|
||||
self.chat_panel = ChatPanel(self._app)
|
||||
|
||||
with gr.Column(scale=3):
|
||||
with gr.Accordion(label="Information panel", open=True):
|
||||
self.info_panel = gr.HTML(elem_id="chat-info-panel")
|
||||
|
@ -54,11 +71,24 @@ class ChatPage(BasePage):
|
|||
self.chat_panel.text_input.submit,
|
||||
self.chat_panel.submit_btn.click,
|
||||
],
|
||||
fn=self.chat_panel.submit_msg,
|
||||
inputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
||||
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
|
||||
fn=self.submit_msg,
|
||||
inputs=[
|
||||
self.chat_panel.text_input,
|
||||
self.chat_panel.chatbot,
|
||||
self._app.user_id,
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation_rn,
|
||||
],
|
||||
outputs=[
|
||||
self.chat_panel.text_input,
|
||||
self.chat_panel.chatbot,
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
],
|
||||
concurrency_limit=20,
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
).success(
|
||||
fn=self.chat_fn,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
|
@ -72,6 +102,7 @@ class ChatPage(BasePage):
|
|||
self.info_panel,
|
||||
self.chat_state,
|
||||
],
|
||||
concurrency_limit=20,
|
||||
show_progress="minimal",
|
||||
).then(
|
||||
fn=self.update_data_source,
|
||||
|
@ -82,6 +113,7 @@ class ChatPage(BasePage):
|
|||
]
|
||||
+ self._indices_input,
|
||||
outputs=None,
|
||||
concurrency_limit=20,
|
||||
)
|
||||
|
||||
self.chat_panel.regen_btn.click(
|
||||
|
@ -98,6 +130,7 @@ class ChatPage(BasePage):
|
|||
self.info_panel,
|
||||
self.chat_state,
|
||||
],
|
||||
concurrency_limit=20,
|
||||
show_progress="minimal",
|
||||
).then(
|
||||
fn=self.update_data_source,
|
||||
|
@ -108,6 +141,7 @@ class ChatPage(BasePage):
|
|||
]
|
||||
+ self._indices_input,
|
||||
outputs=None,
|
||||
concurrency_limit=20,
|
||||
)
|
||||
|
||||
self.chat_panel.chatbot.like(
|
||||
|
@ -116,7 +150,12 @@ class ChatPage(BasePage):
|
|||
outputs=None,
|
||||
)
|
||||
|
||||
self.chat_control.conversation.change(
|
||||
self.chat_control.btn_new.click(
|
||||
self.chat_control.new_conv,
|
||||
inputs=self._app.user_id,
|
||||
outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.chat_control.select_conv,
|
||||
inputs=[self.chat_control.conversation],
|
||||
outputs=[
|
||||
|
@ -124,12 +163,71 @@ class ChatPage(BasePage):
|
|||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.chat_panel.chatbot,
|
||||
self.info_panel,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
self.chat_control.btn_del.click(
|
||||
lambda id: self.toggle_delete(id),
|
||||
inputs=[self.chat_control.conversation_id],
|
||||
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
|
||||
)
|
||||
self.chat_control.btn_del_conf.click(
|
||||
self.chat_control.delete_conv,
|
||||
inputs=[self.chat_control.conversation_id, self._app.user_id],
|
||||
outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
self.chat_control.select_conv,
|
||||
inputs=[self.chat_control.conversation],
|
||||
outputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.chat_panel.chatbot,
|
||||
self.info_panel,
|
||||
]
|
||||
+ self._indices_input,
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
lambda: self.toggle_delete(""),
|
||||
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
|
||||
)
|
||||
self.chat_control.btn_del_cnl.click(
|
||||
lambda: self.toggle_delete(""),
|
||||
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
|
||||
)
|
||||
self.chat_control.conversation_rn_btn.click(
|
||||
self.chat_control.rename_conv,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation_rn,
|
||||
self._app.user_id,
|
||||
],
|
||||
outputs=[self.chat_control.conversation, self.chat_control.conversation],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
self.chat_control.conversation.select(
|
||||
self.chat_control.select_conv,
|
||||
inputs=[self.chat_control.conversation],
|
||||
outputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_control.conversation,
|
||||
self.chat_control.conversation_rn,
|
||||
self.chat_panel.chatbot,
|
||||
self.info_panel,
|
||||
]
|
||||
+ self._indices_input,
|
||||
show_progress="hidden",
|
||||
).then(
|
||||
lambda: self.toggle_delete(""),
|
||||
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
|
||||
)
|
||||
|
||||
self.report_issue.report_btn.click(
|
||||
self.report_issue.report,
|
||||
inputs=[
|
||||
|
@ -140,11 +238,77 @@ class ChatPage(BasePage):
|
|||
self.chat_panel.chatbot,
|
||||
self._app.settings_state,
|
||||
self._app.user_id,
|
||||
self.info_panel,
|
||||
self.chat_state,
|
||||
]
|
||||
+ self._indices_input,
|
||||
outputs=None,
|
||||
)
|
||||
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
|
||||
self.chat_suggestion.example.select(
|
||||
self.chat_suggestion.select_example,
|
||||
outputs=[self.chat_panel.text_input],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
    """Queue the user's message, creating a conversation first if none is active.

    Args:
        chat_input: text typed by the user.
        chat_history: current chatbot history (list of message pairs).
        user_id: id of the current user, forwarded to ``new_conv``.
        conv_id: id of the active conversation; falsy when none is selected.
        conv_name: name of the active conversation.

    Returns:
        A 5-tuple: empty string (clears the text box), the history with
        ``(chat_input, None)`` appended, the conversation id, an update for
        the conversation selector, and the conversation name.

    Raises:
        ValueError: if ``chat_input`` is empty.
    """
    if not chat_input:
        raise ValueError("Input is empty")

    if conv_id:
        # A conversation is already active: keep its id/name untouched.
        return (
            "",
            chat_history + [(chat_input, None)],
            conv_id,
            gr.update(),
            conv_name,
        )

    # No active conversation: create one and read back its stored name.
    created_id, selector_update = self.chat_control.new_conv(user_id)
    with Session(engine) as session:
        query = select(Conversation).where(Conversation.id == created_id)
        created_name = session.exec(query).one().name

    return (
        "",
        chat_history + [(chat_input, None)],
        created_id,
        selector_update,
        created_name,
    )
|
||||
|
||||
def toggle_delete(self, conv_id):
    """Return a pair of visibility updates for the delete-confirmation flow.

    With a truthy ``conv_id`` the first update hides its component and the
    second shows its component; with a falsy ``conv_id`` it is the reverse.
    """
    confirming = bool(conv_id)
    return gr.update(visible=not confirming), gr.update(visible=confirming)
|
||||
|
||||
def on_subscribe_public_events(self):
    """Subscribe the chat page to the app-wide sign-in / sign-out events.

    Nothing is registered unless user management is enabled.

    * ``onSignIn``: reloads the conversation selector for the current user.
    * ``onSignOut``: calls ``select_conv("")`` to reset the chat area.
    """
    if not self._app.f_user_management:
        return

    self._app.subscribe_event(
        name="onSignIn",
        definition={
            "fn": self.chat_control.reload_conv,
            "inputs": [self._app.user_id],
            "outputs": [self.chat_control.conversation],
            "show_progress": "hidden",
        },
    )

    self._app.subscribe_event(
        name="onSignOut",
        definition={
            "fn": lambda: self.chat_control.select_conv(""),
            # select_conv returns (id, id, name, chats, info panel, state,
            # *index selections); the outputs must list one component per
            # returned value, in that order, otherwise the info-panel and
            # state values spill into the index selector components.
            "outputs": [
                self.chat_control.conversation_id,
                self.chat_control.conversation,
                self.chat_control.conversation_rn,
                self.chat_panel.chatbot,
                self.info_panel,
                self.chat_state,
            ]
            + self._indices_input,
            "show_progress": "hidden",
        },
    )
|
||||
|
||||
def update_data_source(self, convo_id, messages, state, *selecteds):
|
||||
"""Update the data source"""
|
||||
|
@ -154,8 +318,12 @@ class ChatPage(BasePage):
|
|||
|
||||
selecteds_ = {}
|
||||
for index in self._app.index_manager.indices:
|
||||
if index.selector != -1:
|
||||
if index.selector is None:
|
||||
continue
|
||||
if isinstance(index.selector, int):
|
||||
selecteds_[str(index.id)] = selecteds[index.selector]
|
||||
else:
|
||||
selecteds_[str(index.id)] = [selecteds[i] for i in index.selector]
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == convo_id)
|
||||
|
@ -205,8 +373,11 @@ class ChatPage(BasePage):
|
|||
retrievers = []
|
||||
for index in self._app.index_manager.indices:
|
||||
index_selected = []
|
||||
if index.selector != -1:
|
||||
if isinstance(index.selector, int):
|
||||
index_selected = selecteds[index.selector]
|
||||
if isinstance(index.selector, tuple):
|
||||
for i in index.selector:
|
||||
index_selected.append(selecteds[i])
|
||||
iretrievers = index.get_retriever_pipelines(settings, index_selected)
|
||||
retrievers += iretrievers
|
||||
|
||||
|
@ -250,7 +421,10 @@ class ChatPage(BasePage):
|
|||
break
|
||||
|
||||
if "output" in response:
|
||||
text += response["output"]
|
||||
if response["output"] is None:
|
||||
text = ""
|
||||
else:
|
||||
text += response["output"]
|
||||
|
||||
if "evidence" in response:
|
||||
if response["evidence"] is None:
|
||||
|
|
26
libs/ktem/ktem/pages/chat/chat_suggestion.py
Normal file
26
libs/ktem/ktem/pages/chat/chat_suggestion.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
import gradio as gr
|
||||
from ktem.app import BasePage
|
||||
from theflow.settings import settings as flowsettings
|
||||
|
||||
|
||||
class ChatSuggestion(BasePage):
    """Collapsible panel that lists sample chat messages in a table."""

    def __init__(self, app):
        self._app = app
        self.on_building_ui()

    def on_building_ui(self):
        """Build the accordion holding a one-column, read-only sample table.

        Samples come from the optional ``KH_FEATURE_CHAT_SUGGESTION_SAMPLES``
        setting (empty list when the setting is absent); each sample becomes
        its own single-cell row.
        """
        rows = [
            [sample]
            for sample in getattr(
                flowsettings, "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", []
            )
        ]
        with gr.Accordion(label="Chat Suggestion", open=False) as self.accordion:
            self.example = gr.DataFrame(
                value=rows,
                headers=["Sample"],
                interactive=False,
                wrap=True,
            )

    def as_gradio_component(self):
        """Return the sample table component."""
        return self.example

    def select_example(self, ev: gr.SelectData):
        """Return the value carried by the select event."""
        return ev.value
|
|
@ -10,6 +10,17 @@ from .common import STATE
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_conv_name_valid(name):
    """Check if the conversation name is valid

    Args:
        name: the proposed conversation name (``None`` is treated as empty).

    Returns:
        An empty string when the name is acceptable, otherwise a
        human-readable error message.
    """
    # A name made only of whitespace renders as blank, so it is rejected
    # the same way as an empty one; None is handled instead of crashing.
    if name is None or not name.strip():
        return "Name cannot be empty"

    if len(name) > 40:
        return "Name cannot be longer than 40 characters"

    return ""
|
||||
|
||||
|
||||
class ConversationControl(BasePage):
|
||||
"""Manage conversation"""
|
||||
|
||||
|
@ -28,9 +39,17 @@ class ConversationControl(BasePage):
|
|||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
self.conversation_new_btn = gr.Button(value="New", min_width=10)
|
||||
self.conversation_del_btn = gr.Button(value="Delete", min_width=10)
|
||||
with gr.Row() as self._new_delete:
|
||||
self.btn_new = gr.Button(value="New", min_width=10)
|
||||
self.btn_del = gr.Button(value="Delete", min_width=10)
|
||||
|
||||
with gr.Row(visible=False) as self._delete_confirm:
|
||||
self.btn_del_conf = gr.Button(
|
||||
value="Delete",
|
||||
variant="primary",
|
||||
min_width=10,
|
||||
)
|
||||
self.btn_del_cnl = gr.Button(value="Cancel", min_width=10)
|
||||
|
||||
with gr.Row():
|
||||
self.conversation_rn = gr.Text(
|
||||
|
@ -52,48 +71,6 @@ class ConversationControl(BasePage):
|
|||
# outputs=[current_state],
|
||||
# )
|
||||
|
||||
def on_subscribe_public_events(self):
|
||||
if self._app.f_user_management:
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
definition={
|
||||
"fn": self.reload_conv,
|
||||
"inputs": [self._app.user_id],
|
||||
"outputs": [self.conversation],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
||||
self._app.subscribe_event(
|
||||
name="onSignOut",
|
||||
definition={
|
||||
"fn": self.reload_conv,
|
||||
"inputs": [self._app.user_id],
|
||||
"outputs": [self.conversation],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
||||
def on_register_events(self):
|
||||
self.conversation_new_btn.click(
|
||||
self.new_conv,
|
||||
inputs=self._app.user_id,
|
||||
outputs=[self.conversation_id, self.conversation],
|
||||
show_progress="hidden",
|
||||
)
|
||||
self.conversation_del_btn.click(
|
||||
self.delete_conv,
|
||||
inputs=[self.conversation_id, self._app.user_id],
|
||||
outputs=[self.conversation_id, self.conversation],
|
||||
show_progress="hidden",
|
||||
)
|
||||
self.conversation_rn_btn.click(
|
||||
self.rename_conv,
|
||||
inputs=[self.conversation_id, self.conversation_rn, self._app.user_id],
|
||||
outputs=[self.conversation, self.conversation],
|
||||
show_progress="hidden",
|
||||
)
|
||||
|
||||
def load_chat_history(self, user_id):
|
||||
"""Reload chat history"""
|
||||
options = []
|
||||
|
@ -112,7 +89,7 @@ class ConversationControl(BasePage):
|
|||
def reload_conv(self, user_id):
|
||||
conv_list = self.load_chat_history(user_id)
|
||||
if conv_list:
|
||||
return gr.update(value=conv_list[0][1], choices=conv_list)
|
||||
return gr.update(value=None, choices=conv_list)
|
||||
else:
|
||||
return gr.update(value=None, choices=[])
|
||||
|
||||
|
@ -133,10 +110,15 @@ class ConversationControl(BasePage):
|
|||
return id_, gr.update(value=id_, choices=history)
|
||||
|
||||
def delete_conv(self, conversation_id, user_id):
|
||||
"""Create new chat"""
|
||||
"""Delete the selected conversation"""
|
||||
if not conversation_id:
|
||||
gr.Warning("No conversation selected.")
|
||||
return None, gr.update()
|
||||
|
||||
if user_id is None:
|
||||
gr.Warning("Please sign in first (Settings → User Settings)")
|
||||
return None, gr.update()
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||
result = session.exec(statement).one()
|
||||
|
@ -161,6 +143,7 @@ class ConversationControl(BasePage):
|
|||
name = result.name
|
||||
selected = result.data_source.get("selected", {})
|
||||
chats = result.data_source.get("messages", [])
|
||||
info_panel = ""
|
||||
state = result.data_source.get("state", STATE)
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
|
@ -168,22 +151,36 @@ class ConversationControl(BasePage):
|
|||
name = ""
|
||||
selected = {}
|
||||
chats = []
|
||||
info_panel = ""
|
||||
state = STATE
|
||||
|
||||
indices = []
|
||||
for index in self._app.index_manager.indices:
|
||||
# assume that the index has selector
|
||||
if index.selector == -1:
|
||||
if index.selector is None:
|
||||
continue
|
||||
indices.append(selected.get(str(index.id), []))
|
||||
if isinstance(index.selector, int):
|
||||
indices.append(selected.get(str(index.id), []))
|
||||
if isinstance(index.selector, tuple):
|
||||
indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
|
||||
|
||||
return id_, id_, name, chats, state, *indices
|
||||
return id_, id_, name, chats, info_panel, state, *indices
|
||||
|
||||
def rename_conv(self, conversation_id, new_name, user_id):
|
||||
"""Rename the conversation"""
|
||||
if user_id is None:
|
||||
gr.Warning("Please sign in first (Settings → User Settings)")
|
||||
return gr.update(), ""
|
||||
|
||||
if not conversation_id:
|
||||
gr.Warning("No conversation selected.")
|
||||
return gr.update(), ""
|
||||
|
||||
errors = is_conv_name_valid(new_name)
|
||||
if errors:
|
||||
gr.Warning(errors)
|
||||
return gr.update(), conversation_id
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||
result = session.exec(statement).one()
|
||||
|
|
|
@ -48,13 +48,19 @@ class ReportIssue(BasePage):
|
|||
chat_history: list,
|
||||
settings: dict,
|
||||
user_id: Optional[int],
|
||||
info_panel: str,
|
||||
chat_state: dict,
|
||||
*selecteds
|
||||
*selecteds,
|
||||
):
|
||||
selecteds_ = {}
|
||||
for index in self._app.index_manager.indices:
|
||||
if index.selector != -1:
|
||||
selecteds_[str(index.id)] = selecteds[index.selector]
|
||||
if index.selector is not None:
|
||||
if isinstance(index.selector, int):
|
||||
selecteds_[str(index.id)] = selecteds[index.selector]
|
||||
elif isinstance(index.selector, tuple):
|
||||
selecteds_[str(index.id)] = [selecteds[_] for _ in index.selector]
|
||||
else:
|
||||
print(f"Unknown selector type: {index.selector}")
|
||||
|
||||
with Session(engine) as session:
|
||||
issue = IssueReport(
|
||||
|
@ -66,6 +72,7 @@ class ReportIssue(BasePage):
|
|||
chat={
|
||||
"conv_id": conv_id,
|
||||
"chat_history": chat_history,
|
||||
"info_panel": info_panel,
|
||||
"chat_state": chat_state,
|
||||
"selecteds": selecteds_,
|
||||
},
|
||||
|
|
|
@ -31,11 +31,10 @@ class LoginPage(BasePage):
|
|||
self.on_building_ui()
|
||||
|
||||
def on_building_ui(self):
|
||||
gr.Markdown("Welcome to Kotaemon")
|
||||
self.usn = gr.Textbox(label="Username")
|
||||
self.pwd = gr.Textbox(label="Password", type="password")
|
||||
self.btn_login = gr.Button("Login")
|
||||
self._dummy = gr.State()
|
||||
gr.Markdown("# Welcome to Kotaemon")
|
||||
self.usn = gr.Textbox(label="Username", visible=False)
|
||||
self.pwd = gr.Textbox(label="Password", type="password", visible=False)
|
||||
self.btn_login = gr.Button("Login", visible=False)
|
||||
|
||||
def on_register_events(self):
|
||||
onSignIn = gr.on(
|
||||
|
@ -45,24 +44,56 @@ class LoginPage(BasePage):
|
|||
outputs=[self._app.user_id, self.usn, self.pwd],
|
||||
show_progress="hidden",
|
||||
js=signin_js,
|
||||
).then(
|
||||
self.toggle_login_visibility,
|
||||
inputs=[self._app.user_id],
|
||||
outputs=[self.usn, self.pwd, self.btn_login],
|
||||
)
|
||||
for event in self._app.get_event("onSignIn"):
|
||||
onSignIn = onSignIn.success(**event)
|
||||
|
||||
def toggle_login_visibility(self, user_id):
    """Return three visibility updates driven by whether a user is signed in.

    Each update sets ``visible`` to ``True`` when ``user_id`` is ``None``
    and to ``False`` otherwise.
    """
    show_login = user_id is None
    return tuple(gr.update(visible=show_login) for _ in range(3))
|
||||
|
||||
def _on_app_created(self):
|
||||
self._app.app.load(
|
||||
None,
|
||||
inputs=None,
|
||||
outputs=[self.usn, self.pwd],
|
||||
onSignIn = self._app.app.load(
|
||||
self.login,
|
||||
inputs=[self.usn, self.pwd],
|
||||
outputs=[self._app.user_id, self.usn, self.pwd],
|
||||
show_progress="hidden",
|
||||
js=fetch_creds,
|
||||
).then(
|
||||
self.toggle_login_visibility,
|
||||
inputs=[self._app.user_id],
|
||||
outputs=[self.usn, self.pwd, self.btn_login],
|
||||
)
|
||||
for event in self._app.get_event("onSignIn"):
|
||||
onSignIn = onSignIn.success(**event)
|
||||
|
||||
def on_subscribe_public_events(self):
    """Re-evaluate the login widgets' visibility whenever a user signs out.

    Registers ``toggle_login_visibility`` on the ``onSignOut`` event, fed by
    the app's user id and writing to the username box, the password box and
    the login button.
    """
    login_widgets = [self.usn, self.pwd, self.btn_login]
    on_sign_out = {
        "fn": self.toggle_login_visibility,
        "inputs": [self._app.user_id],
        "outputs": login_widgets,
        "show_progress": "hidden",
    }
    self._app.subscribe_event(name="onSignOut", definition=on_sign_out)
|
||||
|
||||
def login(self, usn, pwd):
|
||||
if not usn or not pwd:
|
||||
return None, usn, pwd
|
||||
|
||||
hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
|
||||
with Session(engine) as session:
|
||||
stmt = select(User).where(
|
||||
User.username_lower == usn.lower(), User.password == hashed_password
|
||||
User.username_lower == usn.lower().strip(),
|
||||
User.password == hashed_password,
|
||||
)
|
||||
result = session.exec(stmt).all()
|
||||
if result:
|
||||
|
|
|
@ -164,9 +164,14 @@ class SettingsPage(BasePage):
|
|||
show_progress="hidden",
|
||||
)
|
||||
onSignOutClick = self.signout.click(
|
||||
lambda: (None, "Current user: ___"),
|
||||
lambda: (None, "Current user: ___", "", ""),
|
||||
inputs=None,
|
||||
outputs=[self._user_id, self.current_name],
|
||||
outputs=[
|
||||
self._user_id,
|
||||
self.current_name,
|
||||
self.password_change,
|
||||
self.password_change_confirm,
|
||||
],
|
||||
show_progress="hidden",
|
||||
js=signout_js,
|
||||
).then(
|
||||
|
@ -192,8 +197,12 @@ class SettingsPage(BasePage):
|
|||
self.password_change_btn = gr.Button("Change password", interactive=True)
|
||||
|
||||
def change_password(self, user_id, password, password_confirm):
|
||||
if password != password_confirm:
|
||||
gr.Warning("Password does not match")
|
||||
from ktem.pages.admin.user import validate_password
|
||||
|
||||
errors = validate_password(password, password_confirm)
|
||||
if errors:
|
||||
print(errors)
|
||||
gr.Warning(errors)
|
||||
return password, password_confirm
|
||||
|
||||
with Session(engine) as session:
|
||||
|
|
|
@ -34,12 +34,16 @@ class BaseReasoning(BaseComponent):
|
|||
|
||||
@classmethod
def get_pipeline(
    cls,
    user_settings: dict,
    state: dict,
    retrievers: Optional[list["BaseComponent"]] = None,
) -> "BaseReasoning":
    """Get the reasoning pipeline for the app to execute

    This base implementation ignores all of its arguments and simply
    instantiates ``cls`` with no arguments.

    Args:
        user_settings: user settings
        state: conversation state
        retrievers (list): List of retrievers

    Returns:
        A new instance of ``cls``.
    """
    return cls()
|
||||
|
|
|
@ -22,6 +22,8 @@ from kotaemon.indices.splitters import TokenSplitter
|
|||
from kotaemon.llms import ChatLLM, PromptTemplate
|
||||
from kotaemon.loaders.utils.gpt4v import stream_gpt4v
|
||||
|
||||
from .base import BaseReasoning
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVIDENCE_MODE_TEXT = 0
|
||||
|
@ -204,7 +206,7 @@ class AnswerWithContextPipeline(BaseComponent):
|
|||
lang: str = "English" # support English and Japanese
|
||||
|
||||
async def run( # type: ignore
|
||||
self, question: str, evidence: str, evidence_mode: int = 0
|
||||
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
|
||||
) -> Document:
|
||||
"""Answer the question based on the evidence
|
||||
|
||||
|
@ -336,7 +338,7 @@ class RewriteQuestionPipeline(BaseComponent):
|
|||
return Document(text=output)
|
||||
|
||||
|
||||
class FullQAPipeline(BaseComponent):
|
||||
class FullQAPipeline(BaseReasoning):
|
||||
"""Question answering pipeline. Handle from question to answer"""
|
||||
|
||||
class Config:
|
||||
|
@ -352,6 +354,8 @@ class FullQAPipeline(BaseComponent):
|
|||
async def run( # type: ignore
|
||||
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
|
||||
) -> Document: # type: ignore
|
||||
import markdown
|
||||
|
||||
docs = []
|
||||
doc_ids = []
|
||||
if self.use_rewrite:
|
||||
|
@ -364,12 +368,16 @@ class FullQAPipeline(BaseComponent):
|
|||
docs.append(doc)
|
||||
doc_ids.append(doc.doc_id)
|
||||
for doc in docs:
|
||||
# TODO: a better approach to show the information
|
||||
text = markdown.markdown(
|
||||
doc.text, extensions=["markdown.extensions.tables"]
|
||||
)
|
||||
self.report_output(
|
||||
{
|
||||
"evidence": (
|
||||
"<details open>"
|
||||
f"<summary>{doc.metadata['file_name']}</summary>"
|
||||
f"{doc.text}"
|
||||
f"{text}"
|
||||
"</details><br>"
|
||||
)
|
||||
}
|
||||
|
@ -378,7 +386,12 @@ class FullQAPipeline(BaseComponent):
|
|||
|
||||
evidence_mode, evidence = self.evidence_pipeline(docs).content
|
||||
answer = await self.answering_pipeline(
|
||||
question=message, evidence=evidence, evidence_mode=evidence_mode
|
||||
question=message,
|
||||
history=history,
|
||||
evidence=evidence,
|
||||
evidence_mode=evidence_mode,
|
||||
conv_id=conv_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# prepare citation
|
||||
|
@ -388,14 +401,29 @@ class FullQAPipeline(BaseComponent):
|
|||
for quote in fact_with_evidence.substring_quote:
|
||||
for doc in docs:
|
||||
start_idx = doc.text.find(quote)
|
||||
if start_idx >= 0:
|
||||
if start_idx == -1:
|
||||
continue
|
||||
|
||||
end_idx = start_idx + len(quote)
|
||||
|
||||
current_idx = start_idx
|
||||
if "|" not in doc.text[start_idx:end_idx]:
|
||||
spans[doc.doc_id].append(
|
||||
{
|
||||
"start": start_idx,
|
||||
"end": start_idx + len(quote),
|
||||
}
|
||||
{"start": start_idx, "end": end_idx}
|
||||
)
|
||||
break
|
||||
else:
|
||||
while doc.text[current_idx:end_idx].find("|") != -1:
|
||||
match_idx = doc.text[current_idx:end_idx].find("|")
|
||||
spans[doc.doc_id].append(
|
||||
{
|
||||
"start": current_idx,
|
||||
"end": current_idx + match_idx,
|
||||
}
|
||||
)
|
||||
current_idx += match_idx + 2
|
||||
if current_idx > end_idx:
|
||||
break
|
||||
break
|
||||
|
||||
id2docs = {doc.doc_id: doc for doc in docs}
|
||||
lack_evidence = True
|
||||
|
@ -414,12 +442,15 @@ class FullQAPipeline(BaseComponent):
|
|||
if idx < len(ss) - 1:
|
||||
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
|
||||
text += id2docs[id].text[ss[-1]["end"] :]
|
||||
text_out = markdown.markdown(
|
||||
text, extensions=["markdown.extensions.tables"]
|
||||
)
|
||||
self.report_output(
|
||||
{
|
||||
"evidence": (
|
||||
"<details open>"
|
||||
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
|
||||
f"{text}"
|
||||
f"{text_out}"
|
||||
"</details><br>"
|
||||
)
|
||||
}
|
||||
|
@ -434,12 +465,15 @@ class FullQAPipeline(BaseComponent):
|
|||
{"evidence": "Retrieved segments without matching evidence:\n"}
|
||||
)
|
||||
for id in list(not_detected):
|
||||
text_out = markdown.markdown(
|
||||
id2docs[id].text, extensions=["markdown.extensions.tables"]
|
||||
)
|
||||
self.report_output(
|
||||
{
|
||||
"evidence": (
|
||||
"<details>"
|
||||
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
|
||||
f"{id2docs[id].text}"
|
||||
f"{text_out}"
|
||||
"</details><br>"
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user