Fix UI bugs (#8)

* Auto create conversation when the user starts

* Add conversation rename rule check

* Fix empty name during save

* Confirm deleting conversation

* Show warning if users don't select file when upload files in the File Index

* Feedback when user uploads duplicated file

* Limit the file types

* Fix valid username

* Allow login when username with leading and trailing whitespaces

* Improve the user

* Disable admin panel for non-admnin user

* Refresh user lists after creating/deleting users

* Auto logging in

* Clear admin information upon signing out

* Fix unable to receive uploaded filename that include special characters, like !@#$%^&*().pdf

* Set upload validation for FileIndex

* Improve user management UI/UIX

* Show extraction error when indexing file

* Return selected user -1 when signing out

* Fix default supported file types in file index

* Validate changing password

* Allow the selector to contain mulitple gradio components

* A more tolerable placeholder screen

* Allow chat suggestion box

* Increase concurrency limit

* Make adobe loader optional

* Use BaseReasoning

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin 2024-04-03 16:33:54 +07:00 committed by GitHub
parent 43a18ba070
commit ecf09b275f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 936 additions and 255 deletions

1
.gitignore vendored
View File

@ -452,6 +452,7 @@ $RECYCLE.BIN/
.theflow/ .theflow/
# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm # End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,vim,emacs,visualstudiocode,pycharm
*.py[coid]
logs/ logs/
.gitsecret/keys/random_seed .gitsecret/keys/random_seed

View File

@ -52,7 +52,12 @@ repos:
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: additional_dependencies:
[types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"] [
types-PyYAML==6.0.12.11,
"types-requests",
"sqlmodel",
"types-Markdown",
]
args: ["--check-untyped-defs", "--ignore-missing-imports"] args: ["--check-untyped-defs", "--ignore-missing-imports"]
exclude: "^templates/" exclude: "^templates/"
- repo: https://github.com/codespell-project/codespell - repo: https://github.com/codespell-project/codespell

View File

@ -104,17 +104,15 @@ class CitationPipeline(BaseComponent):
print("CitationPipeline: invoking LLM") print("CitationPipeline: invoking LLM")
llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs) llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
print("CitationPipeline: finish invoking LLM") print("CitationPipeline: finish invoking LLM")
except Exception as e:
print(e)
return None
if not llm_output.messages: if not llm_output.messages:
return None return None
function_output = llm_output.messages[0].additional_kwargs["function_call"][ function_output = llm_output.messages[0].additional_kwargs["function_call"][
"arguments" "arguments"
] ]
output = QuestionAnswer.parse_raw(function_output) output = QuestionAnswer.parse_raw(function_output)
except Exception as e:
print(e)
return None
return output return output

View File

@ -5,7 +5,7 @@ from .docx_loader import DocxReader
from .excel_loader import PandasExcelReader from .excel_loader import PandasExcelReader
from .html_loader import HtmlReader from .html_loader import HtmlReader
from .mathpix_loader import MathpixPDFReader from .mathpix_loader import MathpixPDFReader
from .ocr_loader import OCRReader from .ocr_loader import ImageReader, OCRReader
from .unstructured_loader import UnstructuredReader from .unstructured_loader import UnstructuredReader
__all__ = [ __all__ = [
@ -13,6 +13,7 @@ __all__ = [
"BaseReader", "BaseReader",
"PandasExcelReader", "PandasExcelReader",
"MathpixPDFReader", "MathpixPDFReader",
"ImageReader",
"OCRReader", "OCRReader",
"DirectoryReader", "DirectoryReader",
"UnstructuredReader", "UnstructuredReader",

View File

@ -10,14 +10,6 @@ from llama_index.readers.base import BaseReader
from kotaemon.base import Document from kotaemon.base import Document
from .utils.adobe import (
generate_figure_captions,
load_json,
parse_figure_paths,
parse_table_paths,
request_adobe_service,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_VLM_ENDPOINT = ( DEFAULT_VLM_ENDPOINT = (
@ -74,6 +66,13 @@ class AdobeReader(BaseReader):
includes 3 types: text, table, and image includes 3 types: text, table, and image
""" """
from .utils.adobe import (
generate_figure_captions,
load_json,
parse_figure_paths,
parse_table_paths,
request_adobe_service,
)
filename = file.name filename = file.name
filepath = str(Path(file).resolve()) filepath = str(Path(file).resolve())

View File

@ -125,3 +125,70 @@ class OCRReader(BaseReader):
) )
return documents return documents
class ImageReader(BaseReader):
"""Read PDF using OCR, with high focus on table extraction
Example:
```python
>> from knowledgehub.loaders import OCRReader
>> reader = OCRReader()
>> documents = reader.load_data("path/to/pdf")
```
Args:
endpoint: URL to FullOCR endpoint. If not provided, will look for
environment variable `OCR_READER_ENDPOINT` or use the default
`knowledgehub.loaders.ocr_loader.DEFAULT_OCR_ENDPOINT`
(http://127.0.0.1:8000/v2/ai/infer/)
use_ocr: whether to use OCR to read text (e.g: from images, tables) in the PDF
If False, only the table and text within table cells will be extracted.
"""
def __init__(self, endpoint: Optional[str] = None):
"""Init the OCR reader with OCR endpoint (FullOCR pipeline)"""
super().__init__()
self.ocr_endpoint = endpoint or os.getenv(
"OCR_READER_ENDPOINT", DEFAULT_OCR_ENDPOINT
)
def load_data(
self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
) -> List[Document]:
"""Load data using OCR reader
Args:
file_path (Path): Path to PDF file
debug_path (Path): Path to store debug image output
artifact_path (Path): Path to OCR endpoints artifacts directory
Returns:
List[Document]: list of documents extracted from the PDF file
"""
file_path = Path(file_path).resolve()
with file_path.open("rb") as content:
files = {"input": content}
data = {"job_id": uuid4(), "table_only": False}
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
ocr_results = resp.json()["result"]
extra_info = extra_info or {}
result = []
for ocr_result in ocr_results:
result.append(
Document(
content=ocr_result["csv_string"],
metadata=extra_info,
)
)
return result

View File

@ -229,7 +229,9 @@ class BasePage:
def _on_app_created(self): def _on_app_created(self):
"""Called when the app is created""" """Called when the app is created"""
def as_gradio_component(self) -> Optional[gr.components.Component]: def as_gradio_component(
self,
) -> Optional[gr.components.Component | list[gr.components.Component]]:
"""Return the gradio components responsible for events """Return the gradio components responsible for events
Note: in ideal scenario, this method shouldn't be necessary. Note: in ideal scenario, this method shouldn't be necessary.

View File

@ -1,6 +1,6 @@
import abc import abc
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Any, Optional
if TYPE_CHECKING: if TYPE_CHECKING:
from ktem.app import BasePage from ktem.app import BasePage
@ -57,7 +57,7 @@ class BaseIndex(abc.ABC):
self._app = app self._app = app
self.id = id self.id = id
self.name = name self.name = name
self._config = config # admin settings self.config = config # admin settings
def on_create(self): def on_create(self):
"""Create the index for the first time""" """Create the index for the first time"""
@ -121,7 +121,7 @@ class BaseIndex(abc.ABC):
... ...
def get_retriever_pipelines( def get_retriever_pipelines(
self, settings: dict, selected: Optional[list] self, settings: dict, selected: Any = None
) -> list["BaseComponent"]: ) -> list["BaseComponent"]:
"""Return the retriever pipelines to retrieve the entity from the index""" """Return the retriever pipelines to retrieve the entity from the index"""
return [] return []

View File

@ -127,3 +127,11 @@ class BaseFileIndexIndexing(BaseComponent):
the absolute file storage path to the file the absolute file storage path to the file
""" """
raise NotImplementedError raise NotImplementedError
def warning(self, msg):
"""Log a warning message
Args:
msg: the message to log
"""
print(msg)

View File

@ -13,7 +13,6 @@ from theflow.utils.modules import import_dotted_string
from kotaemon.storages import BaseDocumentStore, BaseVectorStore from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .ui import FileIndexPage, FileSelector
class FileIndex(BaseIndex): class FileIndex(BaseIndex):
@ -77,9 +76,15 @@ class FileIndex(BaseIndex):
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing] self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]] self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
self._selector_ui_cls: Type
self._selector_ui: Any = None
self._index_ui_cls: Type
self._index_ui: Any = None
self._setup_indexing_cls() self._setup_indexing_cls()
self._setup_retriever_cls() self._setup_retriever_cls()
self._setup_file_index_ui_cls()
self._setup_file_selector_ui_cls()
self._default_settings: dict[str, dict] = {} self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {} self._setting_mappings: dict[str, dict] = {}
@ -91,14 +96,14 @@ class FileIndex(BaseIndex):
The indexing class will is retrieved from the following order. Stop at the The indexing class will is retrieved from the following order. Stop at the
first order found: first order found:
- `FILE_INDEX_PIPELINE` in self._config - `FILE_INDEX_PIPELINE` in self.config
- `FILE_INDEX_{id}_PIPELINE` in the flowsettings - `FILE_INDEX_{id}_PIPELINE` in the flowsettings
- `FILE_INDEX_PIPELINE` in the flowsettings - `FILE_INDEX_PIPELINE` in the flowsettings
- The default .pipelines.IndexDocumentPipeline - The default .pipelines.IndexDocumentPipeline
""" """
if "FILE_INDEX_PIPELINE" in self._config: if "FILE_INDEX_PIPELINE" in self.config:
self._indexing_pipeline_cls = import_dotted_string( self._indexing_pipeline_cls = import_dotted_string(
self._config["FILE_INDEX_PIPELINE"], safe=False self.config["FILE_INDEX_PIPELINE"], safe=False
) )
return return
@ -125,15 +130,15 @@ class FileIndex(BaseIndex):
The retriever classes will is retrieved from the following order. Stop at the The retriever classes will is retrieved from the following order. Stop at the
first order found: first order found:
- `FILE_INDEX_RETRIEVER_PIPELINES` in self._config - `FILE_INDEX_RETRIEVER_PIPELINES` in self.config
- `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings - `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings
- `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings - `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings
- The default .pipelines.DocumentRetrievalPipeline - The default .pipelines.DocumentRetrievalPipeline
""" """
if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config: if "FILE_INDEX_RETRIEVER_PIPELINES" in self.config:
self._retriever_pipeline_cls = [ self._retriever_pipeline_cls = [
import_dotted_string(each, safe=False) import_dotted_string(each, safe=False)
for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"] for each in self.config["FILE_INDEX_RETRIEVER_PIPELINES"]
] ]
return return
@ -157,6 +162,76 @@ class FileIndex(BaseIndex):
self._retriever_pipeline_cls = [DocumentRetrievalPipeline] self._retriever_pipeline_cls = [DocumentRetrievalPipeline]
def _setup_file_selector_ui_cls(self):
"""Retrieve the file selector UI for the file index
There can be multiple retriever classes.
The retriever classes will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_SELECTOR_UI` in self.config
- `FILE_INDEX_{id}_SELECTOR_UI` in the flowsettings
- `FILE_INDEX_SELECTOR_UI` in the flowsettings
- The default .ui.FileSelector
"""
if "FILE_INDEX_SELECTOR_UI" in self.config:
self._selector_ui_cls = import_dotted_string(
self.config["FILE_INDEX_SELECTOR_UI"], safe=False
)
return
if hasattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"):
self._selector_ui_cls = import_dotted_string(
getattr(flowsettings, f"FILE_INDEX_{self.id}_SELECTOR_UI"),
safe=False,
)
return
if hasattr(flowsettings, "FILE_INDEX_SELECTOR_UI"):
self._selector_ui_cls = import_dotted_string(
getattr(flowsettings, "FILE_INDEX_SELECTOR_UI"), safe=False
)
return
from .ui import FileSelector
self._selector_ui_cls = FileSelector
def _setup_file_index_ui_cls(self):
"""Retrieve the Index UI class
There can be multiple retriever classes.
The retriever classes will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_UI` in self.config
- `FILE_INDEX_{id}_UI` in the flowsettings
- `FILE_INDEX_UI` in the flowsettings
- The default .ui.FileIndexPage
"""
if "FILE_INDEX_UI" in self.config:
self._index_ui_cls = import_dotted_string(
self.config["FILE_INDEX_UI"], safe=False
)
return
if hasattr(flowsettings, f"FILE_INDEX_{self.id}_UI"):
self._index_ui_cls = import_dotted_string(
getattr(flowsettings, f"FILE_INDEX_{self.id}_UI"),
safe=False,
)
return
if hasattr(flowsettings, "FILE_INDEX_UI"):
self._index_ui_cls = import_dotted_string(
getattr(flowsettings, "FILE_INDEX_UI"), safe=False
)
return
from .ui import FileIndexPage
self._index_ui_cls = FileIndexPage
def on_create(self): def on_create(self):
"""Create the index for the first time """Create the index for the first time
@ -165,6 +240,13 @@ class FileIndex(BaseIndex):
2. Create the vectorstore 2. Create the vectorstore
3. Create the docstore 3. Create the docstore
""" """
file_types_str = self.config.get(
"supported_file_types",
self.get_admin_settings()["supported_file_types"]["value"],
)
file_types = [each.strip() for each in file_types_str.split(",")]
self.config["supported_file_types"] = file_types
self._resources["Source"].metadata.create_all(engine) # type: ignore self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore self._resources["Index"].metadata.create_all(engine) # type: ignore
self._fs_path.mkdir(parents=True, exist_ok=True) self._fs_path.mkdir(parents=True, exist_ok=True)
@ -180,10 +262,14 @@ class FileIndex(BaseIndex):
shutil.rmtree(self._fs_path) shutil.rmtree(self._fs_path)
def get_selector_component_ui(self): def get_selector_component_ui(self):
return FileSelector(self._app, self) if self._selector_ui is None:
self._selector_ui = self._selector_ui_cls(self._app, self)
return self._selector_ui
def get_index_page_ui(self): def get_index_page_ui(self):
return FileIndexPage(self._app, self) if self._index_ui is None:
self._index_ui = self._index_ui_cls(self._app, self)
return self._index_ui
def get_user_settings(self): def get_user_settings(self):
if self._default_settings: if self._default_settings:
@ -210,7 +296,31 @@ class FileIndex(BaseIndex):
"value": embedding_default, "value": embedding_default,
"component": "dropdown", "component": "dropdown",
"choices": embedding_choices, "choices": embedding_choices,
} },
"supported_file_types": {
"name": "Supported file types",
"value": (
"image, .pdf, .txt, .csv, .xlsx, .doc, .docx, .pptx, .html, .zip"
),
"component": "text",
},
"max_file_size": {
"name": "Max file size (MB) - set 0 to disable",
"value": 1000,
"component": "number",
},
"max_number_of_files": {
"name": "Max number of files that can be indexed - set 0 to disable",
"value": 0,
"component": "number",
},
"max_number_of_text_length": {
"name": (
"Max amount of characters that can be indexed - set 0 to disable"
),
"value": 0,
"component": "number",
},
} }
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing: def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
@ -224,14 +334,15 @@ class FileIndex(BaseIndex):
else: else:
stripped_settings[key] = value stripped_settings[key] = value
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config) obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.set_resources(resources=self._resources) obj.set_resources(resources=self._resources)
return obj return obj
def get_retriever_pipelines( def get_retriever_pipelines(
self, settings: dict, selected: Optional[list] = None self, settings: dict, selected: Any = None
) -> list["BaseFileIndexRetriever"]: ) -> list["BaseFileIndexRetriever"]:
# retrieval settings
prefix = f"index.options.{self.id}." prefix = f"index.options.{self.id}."
stripped_settings = {} stripped_settings = {}
for key, value in settings.items(): for key, value in settings.items():
@ -240,9 +351,12 @@ class FileIndex(BaseIndex):
else: else:
stripped_settings[key] = value stripped_settings[key] = value
# transform selected id
selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)
retrievers = [] retrievers = []
for cls in self._retriever_pipeline_cls: for cls in self._retriever_pipeline_cls:
obj = cls.get_pipeline(stripped_settings, self._config, selected) obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
if obj is None: if obj is None:
continue continue
obj.set_resources(self._resources) obj.set_resources(self._resources)

View File

@ -9,6 +9,7 @@ from hashlib import sha256
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import gradio as gr
from ktem.components import embeddings, filestorage_path from ktem.components import embeddings, filestorage_path
from ktem.db.models import engine from ktem.db.models import engine
from llama_index.vector_stores import ( from llama_index.vector_stores import (
@ -18,7 +19,7 @@ from llama_index.vector_stores import (
MetadataFilters, MetadataFilters,
) )
from llama_index.vector_stores.types import VectorStoreQueryMode from llama_index.vector_stores.types import VectorStoreQueryMode
from sqlalchemy import select from sqlalchemy import delete, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from theflow.settings import settings from theflow.settings import settings
from theflow.utils.modules import import_dotted_string from theflow.utils.modules import import_dotted_string
@ -279,6 +280,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
to_index: list[str] = [] to_index: list[str] = []
file_to_hash: dict[str, str] = {} file_to_hash: dict[str, str] = {}
errors = [] errors = []
to_update = []
for file_path in file_paths: for file_path in file_paths:
abs_path = str(Path(file_path).resolve()) abs_path = str(Path(file_path).resolve())
@ -291,16 +293,26 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
statement = select(Source).where(Source.name == Path(abs_path).name) statement = select(Source).where(Source.name == Path(abs_path).name)
item = session.execute(statement).first() item = session.execute(statement).first()
if item and not reindex: if item:
if not reindex:
errors.append(Path(abs_path).name) errors.append(Path(abs_path).name)
continue continue
else:
to_update.append(Path(abs_path).name)
to_index.append(abs_path) to_index.append(abs_path)
if errors: if errors:
error_files = ", ".join(errors)
if len(error_files) > 100:
error_files = error_files[:80] + "..."
print( print(
"Files already exist. Please rename/remove them or enable reindex.\n" "Skip these files already exist. Please rename/remove them or "
f"{errors}" f"enable reindex:\n{errors}"
)
self.warning(
"Skip these files already exist. Please rename/remove them or "
f"enable reindex:\n{error_files}"
) )
if not to_index: if not to_index:
@ -310,9 +322,19 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
for path in to_index: for path in to_index:
shutil.copy(path, filestorage_path / file_to_hash[path]) shutil.copy(path, filestorage_path / file_to_hash[path])
# prepare record info # extract the file & prepare record info
file_to_source: dict = {} file_to_source: dict = {}
extraction_errors = []
nodes = []
for file_path, file_hash in file_to_hash.items(): for file_path, file_hash in file_to_hash.items():
if str(Path(file_path).resolve()) not in to_index:
continue
extraction_result = self.file_ingestor(file_path)
if not extraction_result:
extraction_errors.append(Path(file_path).name)
continue
nodes.extend(extraction_result)
source = Source( source = Source(
name=Path(file_path).name, name=Path(file_path).name,
path=file_hash, path=file_hash,
@ -320,9 +342,23 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
) )
file_to_source[file_path] = source file_to_source[file_path] = source
# extract the files if extraction_errors:
nodes = self.file_ingestor(to_index) msg = "Failed to extract these files: {}".format(
print("Extracted", len(to_index), "files into", len(nodes), "nodes") ", ".join(extraction_errors)
)
print(msg)
self.warning(msg)
if not nodes:
return [], []
print(
"Extracted",
len(to_index) - len(extraction_errors),
"files into",
len(nodes),
"nodes",
)
# index the files # index the files
print("Indexing the files into vector store") print("Indexing the files into vector store")
@ -332,7 +368,11 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
# persist to the index # persist to the index
print("Persisting the vector and the document into index") print("Persisting the vector and the document into index")
file_ids = [] file_ids = []
to_update = list(set(to_update))
with Session(engine) as session: with Session(engine) as session:
if to_update:
session.execute(delete(Source).where(Source.name.in_(to_update)))
for source in file_to_source.values(): for source in file_to_source.values():
session.add(source) session.add(source)
session.commit() session.commit()
@ -404,3 +444,6 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
super().set_resources(resources) super().set_resources(resources)
self.indexing_vector_pipeline.vector_store = self._VS self.indexing_vector_pipeline.vector_store = self._VS
self.indexing_vector_pipeline.doc_store = self._DS self.indexing_vector_pipeline.doc_store = self._DS
def warning(self, msg):
gr.Warning(msg)

View File

@ -1,29 +1,48 @@
import os import os
import tempfile import tempfile
from pathlib import Path
import gradio as gr import gradio as gr
import pandas as pd import pandas as pd
from gradio.data_classes import FileData
from gradio.utils import NamedString
from ktem.app import BasePage from ktem.app import BasePage
from ktem.db.engine import engine from ktem.db.engine import engine
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
class File(gr.File):
"""Subclass from gr.File to maintain the original filename
The issue happens when user uploads file with name like: !@#$%%^&*().pdf
"""
def _process_single_file(self, f: FileData) -> NamedString | bytes:
file_name = f.path
if self.type == "filepath":
if f.orig_name and Path(file_name).name != f.orig_name:
file_name = str(Path(file_name).parent / f.orig_name)
os.rename(f.path, file_name)
file = tempfile.NamedTemporaryFile(delete=False, dir=self.GRADIO_CACHE)
file.name = file_name
return NamedString(file_name)
elif self.type == "binary":
with open(file_name, "rb") as file_data:
return file_data.read()
else:
raise ValueError(
"Unknown type: "
+ str(type)
+ ". Please choose from: 'filepath', 'binary'."
)
class DirectoryUpload(BasePage): class DirectoryUpload(BasePage):
def __init__(self, app): def __init__(self, app, index):
self._app = app super().__init__(app)
self._supported_file_types = [ self._index = index
"image", self._supported_file_types = self._index.config.get("supported_file_types", [])
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
self.on_building_ui() self.on_building_ui()
def on_building_ui(self): def on_building_ui(self):
@ -50,18 +69,7 @@ class FileIndexPage(BasePage):
def __init__(self, app, index): def __init__(self, app, index):
super().__init__(app) super().__init__(app)
self._index = index self._index = index
self._supported_file_types = [ self._supported_file_types = self._index.config.get("supported_file_types", [])
"image",
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
self.selected_panel_false = "Selected file: (please select above)" self.selected_panel_false = "Selected file: (please select above)"
self.selected_panel_true = "Selected file: {name}" self.selected_panel_true = "Selected file: {name}"
# TODO: on_building_ui is not correctly named if it's always called in # TODO: on_building_ui is not correctly named if it's always called in
@ -69,13 +77,32 @@ class FileIndexPage(BasePage):
self.public_events = [f"onFileIndex{index.id}Changed"] self.public_events = [f"onFileIndex{index.id}Changed"]
self.on_building_ui() self.on_building_ui()
def upload_instruction(self) -> str:
msgs = []
if self._supported_file_types:
msgs.append(
f"- Supported file types: {', '.join(self._supported_file_types)}"
)
if max_file_size := self._index.config.get("max_file_size", 0):
msgs.append(f"- Maximum file size: {max_file_size} MB")
if max_number_of_files := self._index.config.get("max_number_of_files", 0):
msgs.append(f"- The index can have maximum {max_number_of_files} files")
if msgs:
return "\n".join(msgs)
return ""
def on_building_ui(self): def on_building_ui(self):
"""Build the UI of the app""" """Build the UI of the app"""
with gr.Accordion(label="File upload", open=False): with gr.Accordion(label="File upload", open=True) as self.upload:
gr.Markdown( msg = self.upload_instruction()
f"Supported file types: {', '.join(self._supported_file_types)}", if msg:
) gr.Markdown(msg)
self.files = gr.File(
self.files = File(
file_types=self._supported_file_types, file_types=self._supported_file_types,
file_count="multiple", file_count="multiple",
container=False, container=False,
@ -98,18 +125,20 @@ class FileIndexPage(BasePage):
interactive=False, interactive=False,
) )
with gr.Row(): with gr.Row() as self.selection_info:
self.selected_file_id = gr.State(value=None) self.selected_file_id = gr.State(value=None)
self.selected_panel = gr.Markdown(self.selected_panel_false) self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button("Deselect", visible=False) self.deselect_button = gr.Button("Deselect", visible=False)
with gr.Row(): with gr.Row() as self.tools:
with gr.Column(): with gr.Column():
self.view_button = gr.Button("View Text (WIP)") self.view_button = gr.Button("View Text (WIP)")
with gr.Column(): with gr.Column():
self.delete_button = gr.Button("Delete") self.delete_button = gr.Button("Delete")
with gr.Row(): with gr.Row():
self.delete_yes = gr.Button("Confirm Delete", visible=False) self.delete_yes = gr.Button(
"Confirm Delete", variant="primary", visible=False
)
self.delete_no = gr.Button("Cancel", visible=False) self.delete_no = gr.Button("Cancel", visible=False)
def on_subscribe_public_events(self): def on_subscribe_public_events(self):
@ -242,10 +271,12 @@ class FileIndexPage(BasePage):
self._app.settings_state, self._app.settings_state,
], ],
outputs=[self.file_output], outputs=[self.file_output],
concurrency_limit=20,
).then( ).then(
fn=self.list_file, fn=self.list_file,
inputs=None, inputs=None,
outputs=[self.file_list_state, self.file_list], outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
) )
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"): for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onUploaded = onUploaded.then(**event) onUploaded = onUploaded.then(**event)
@ -274,6 +305,15 @@ class FileIndexPage(BasePage):
selected_files: the list of files already selected selected_files: the list of files already selected
settings: the settings of the app settings: the settings of the app
""" """
if not files:
gr.Info("No uploaded file")
return gr.update()
errors = self.validate(files)
if errors:
gr.Warning(", ".join(errors))
return gr.update()
gr.Info(f"Start indexing {len(files)} files...") gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline # get the pipeline
@ -409,6 +449,35 @@ class FileIndexPage(BasePage):
name=list_files["name"][ev.index[0]] name=list_files["name"][ev.index[0]]
) )
def validate(self, files: list[str]):
"""Validate if the files are valid"""
paths = [Path(file) for file in files]
errors = []
if max_file_size := self._index.config.get("max_file_size", 0):
errors_max_size = []
for path in paths:
if path.stat().st_size > max_file_size * 1e6:
errors_max_size.append(path.name)
if errors_max_size:
str_errors = ", ".join(errors_max_size)
if len(str_errors) > 60:
str_errors = str_errors[:55] + "..."
errors.append(
f"Maximum file size ({max_file_size} MB) exceeded: {str_errors}"
)
if max_number_of_files := self._index.config.get("max_number_of_files", 0):
with Session(engine) as session:
current_num_files = session.query(
self._index._db_tables["Source"].id
).count()
if len(paths) + current_num_files > max_number_of_files:
errors.append(
f"Maximum number of files ({max_number_of_files}) will be exceeded"
)
return errors
class FileSelector(BasePage): class FileSelector(BasePage):
"""File selector UI in the Chat page""" """File selector UI in the Chat page"""
@ -430,6 +499,9 @@ class FileSelector(BasePage):
def as_gradio_component(self): def as_gradio_component(self):
return self.selector return self.selector
def get_selected_ids(self, selected):
return selected
def load_files(self, selected_files): def load_files(self, selected_files):
options = [] options = []
available_ids = [] available_ids = []

View File

@ -1,4 +1,4 @@
from typing import Type from typing import Optional, Type
from ktem.db.models import engine from ktem.db.models import engine
from sqlmodel import Session, select from sqlmodel import Session, select
@ -49,15 +49,19 @@ class IndexManager:
Returns: Returns:
BaseIndex: the index object BaseIndex: the index object
""" """
index_cls = import_dotted_string(index_type, safe=False)
index = index_cls(app=self._app, id=id, name=name, config=config)
index.on_create()
with Session(engine) as session: with Session(engine) as session:
index_entry = Index(id=id, name=name, config=config, index_type=index_type) index_entry = Index(
id=index.id, name=index.name, config=index.config, index_type=index_type
)
session.add(index_entry) session.add(index_entry)
session.commit() session.commit()
session.refresh(index_entry) session.refresh(index_entry)
index_cls = import_dotted_string(index_type, safe=False) index.id = index_entry.id
index = index_cls(app=self._app, id=id, name=name, config=config)
index.on_create()
return index return index
@ -77,7 +81,7 @@ class IndexManager:
self._indices.append(index) self._indices.append(index)
return index return index
def exists(self, id: int) -> bool: def exists(self, id: Optional[int] = None, name: Optional[str] = None) -> bool:
"""Check if the index exists """Check if the index exists
Args: Args:
@ -86,10 +90,20 @@ class IndexManager:
Returns: Returns:
bool: True if the index exists, False otherwise bool: True if the index exists, False otherwise
""" """
if id:
with Session(engine) as session: with Session(engine) as session:
index = session.get(Index, id) index = session.get(Index, id)
return index is not None return index is not None
if name:
with Session(engine) as session:
index = session.exec(
select(Index).where(Index.name == name)
).one_or_none()
return index is not None
return False
def on_application_startup(self): def on_application_startup(self):
"""This method is called by the base application when the application starts """This method is called by the base application when the application starts

View File

@ -27,7 +27,7 @@ class App(BaseApp):
if self.f_user_management: if self.f_user_management:
from ktem.pages.login import LoginPage from ktem.pages.login import LoginPage
with gr.Tab("Login", elem_id="login-tab") as self._tabs["login-tab"]: with gr.Tab("Welcome", elem_id="login-tab") as self._tabs["login-tab"]:
self.login_page = LoginPage(self) self.login_page = LoginPage(self)
with gr.Tab( with gr.Tab(
@ -62,6 +62,9 @@ class App(BaseApp):
def on_subscribe_public_events(self): def on_subscribe_public_events(self):
if self.f_user_management: if self.f_user_management:
from ktem.db.engine import engine
from ktem.db.models import User
from sqlmodel import Session, select
def signed_in_out(user_id): def signed_in_out(user_id):
if not user_id: if not user_id:
@ -73,15 +76,32 @@ class App(BaseApp):
) )
for k in self._tabs.keys() for k in self._tabs.keys()
) )
with Session(engine) as session:
user = session.exec(select(User).where(User.id == user_id)).first()
if user is None:
return list( return list(
( (
gr.update(visible=True) gr.update(visible=True)
if k != "login-tab" if k == "login-tab"
else gr.update(visible=False) else gr.update(visible=False)
) )
for k in self._tabs.keys() for k in self._tabs.keys()
) )
is_admin = user.admin
tabs_update = []
for k in self._tabs.keys():
if k == "login-tab":
tabs_update.append(gr.update(visible=False))
elif k == "admin-tab":
tabs_update.append(gr.update(visible=is_admin))
else:
tabs_update.append(gr.update(visible=True))
return tabs_update
self.subscribe_event( self.subscribe_event(
name="onSignIn", name="onSignIn",
definition={ definition={

View File

@ -40,7 +40,7 @@ def validate_username(usn):
if len(usn) > 32: if len(usn) > 32:
errors.append("Username must be at most 32 characters long") errors.append("Username must be at most 32 characters long")
if not usn.strip("_").isalnum(): if not usn.replace("_", "").isalnum():
errors.append( errors.append(
"Username must contain only alphanumeric characters and underscores" "Username must contain only alphanumeric characters and underscores"
) )
@ -97,8 +97,6 @@ def validate_password(pwd, pwd_cnf):
class UserManagement(BasePage): class UserManagement(BasePage):
def __init__(self, app): def __init__(self, app):
self._app = app self._app = app
self.selected_panel_false = "Selected user: (please select above)"
self.selected_panel_true = "Selected user: {name}"
self.on_building_ui() self.on_building_ui()
if hasattr(flowsettings, "KH_FEATURE_USER_MANAGEMENT_ADMIN") and hasattr( if hasattr(flowsettings, "KH_FEATURE_USER_MANAGEMENT_ADMIN") and hasattr(
@ -126,7 +124,38 @@ class UserManagement(BasePage):
gr.Info(f'User "{usn}" created successfully') gr.Info(f'User "{usn}" created successfully')
def on_building_ui(self): def on_building_ui(self):
with gr.Accordion(label="Create user", open=False): with gr.Tab(label="User list"):
self.state_user_list = gr.State(value=None)
self.user_list = gr.DataFrame(
headers=["id", "name", "admin"],
interactive=False,
)
with gr.Group(visible=False) as self._selected_panel:
self.selected_user_id = gr.Number(value=-1, visible=False)
self.usn_edit = gr.Textbox(label="Username")
with gr.Row():
self.pwd_edit = gr.Textbox(label="Change password", type="password")
self.pwd_cnf_edit = gr.Textbox(
label="Confirm change password",
type="password",
)
self.admin_edit = gr.Checkbox(label="Admin")
with gr.Row() as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button("Save")
with gr.Column():
self.btn_delete = gr.Button("Delete")
with gr.Row():
self.btn_delete_yes = gr.Button(
"Confirm delete", variant="primary", visible=False
)
self.btn_delete_no = gr.Button("Cancel", visible=False)
with gr.Column():
self.btn_close = gr.Button("Close")
with gr.Tab(label="Create user"):
self.usn_new = gr.Textbox(label="Username", interactive=True) self.usn_new = gr.Textbox(label="Username", interactive=True)
self.pwd_new = gr.Textbox( self.pwd_new = gr.Textbox(
label="Password", type="password", interactive=True label="Password", type="password", interactive=True
@ -139,52 +168,28 @@ class UserManagement(BasePage):
gr.Markdown(PASSWORD_RULE) gr.Markdown(PASSWORD_RULE)
self.btn_new = gr.Button("Create user") self.btn_new = gr.Button("Create user")
gr.Markdown("## User list")
self.btn_list_user = gr.Button("Refresh user list")
self.state_user_list = gr.State(value=None)
self.user_list = gr.DataFrame(
headers=["id", "name", "admin"],
interactive=False,
)
with gr.Row():
self.selected_user_id = gr.State(value=None)
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button("Deselect", visible=False)
with gr.Group():
self.btn_delete = gr.Button("Delete user")
with gr.Row():
self.btn_delete_yes = gr.Button("Confirm", visible=False)
self.btn_delete_no = gr.Button("Cancel", visible=False)
gr.Markdown("## User details")
self.usn_edit = gr.Textbox(label="Username")
self.pwd_edit = gr.Textbox(label="Password", type="password")
self.pwd_cnf_edit = gr.Textbox(label="Confirm password", type="password")
self.admin_edit = gr.Checkbox(label="Admin")
self.btn_edit_save = gr.Button("Save")
def on_register_events(self): def on_register_events(self):
self.btn_new.click( self.btn_new.click(
self.create_user, self.create_user,
inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new], inputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
outputs=None, outputs=[self.usn_new, self.pwd_new, self.pwd_cnf_new],
) ).then(
self.btn_list_user.click( self.list_users,
self.list_users, inputs=None, outputs=[self.state_user_list, self.user_list] inputs=self._app.user_id,
outputs=[self.state_user_list, self.user_list],
) )
self.user_list.select( self.user_list.select(
self.select_user, self.select_user,
inputs=self.user_list, inputs=self.user_list,
outputs=[self.selected_user_id, self.selected_panel], outputs=[self.selected_user_id],
show_progress="hidden", show_progress="hidden",
) )
self.selected_panel.change( self.selected_user_id.change(
self.on_selected_user_change, self.on_selected_user_change,
inputs=[self.selected_user_id], inputs=[self.selected_user_id],
outputs=[ outputs=[
self.deselect_button, self._selected_panel,
self._selected_panel_btn,
# delete section # delete section
self.btn_delete, self.btn_delete,
self.btn_delete_yes, self.btn_delete_yes,
@ -197,12 +202,6 @@ class UserManagement(BasePage):
], ],
show_progress="hidden", show_progress="hidden",
) )
self.deselect_button.click(
lambda: (None, self.selected_panel_false),
inputs=None,
outputs=[self.selected_user_id, self.selected_panel],
show_progress="hidden",
)
self.btn_delete.click( self.btn_delete.click(
self.on_btn_delete_click, self.on_btn_delete_click,
inputs=[self.selected_user_id], inputs=[self.selected_user_id],
@ -211,9 +210,13 @@ class UserManagement(BasePage):
) )
self.btn_delete_yes.click( self.btn_delete_yes.click(
self.delete_user, self.delete_user,
inputs=[self.selected_user_id], inputs=[self._app.user_id, self.selected_user_id],
outputs=[self.selected_user_id, self.selected_panel], outputs=[self.selected_user_id],
show_progress="hidden", show_progress="hidden",
).then(
self.list_users,
inputs=self._app.user_id,
outputs=[self.state_user_list, self.user_list],
) )
self.btn_delete_no.click( self.btn_delete_no.click(
lambda: ( lambda: (
@ -234,21 +237,53 @@ class UserManagement(BasePage):
self.pwd_cnf_edit, self.pwd_cnf_edit,
self.admin_edit, self.admin_edit,
], ],
outputs=None, outputs=[self.pwd_edit, self.pwd_cnf_edit],
show_progress="hidden", show_progress="hidden",
).then(
self.list_users,
inputs=self._app.user_id,
outputs=[self.state_user_list, self.user_list],
)
self.btn_close.click(
lambda: -1,
outputs=[self.selected_user_id],
)
def on_subscribe_public_events(self):
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.list_users,
"inputs": [self._app.user_id],
"outputs": [self.state_user_list, self.user_list],
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": lambda: ("", "", "", None, None, -1),
"outputs": [
self.usn_new,
self.pwd_new,
self.pwd_cnf_new,
self.state_user_list,
self.user_list,
self.selected_user_id,
],
},
) )
def create_user(self, usn, pwd, pwd_cnf): def create_user(self, usn, pwd, pwd_cnf):
errors = validate_username(usn) errors = validate_username(usn)
if errors: if errors:
gr.Warning(errors) gr.Warning(errors)
return return usn, pwd, pwd_cnf
errors = validate_password(pwd, pwd_cnf) errors = validate_password(pwd, pwd_cnf)
print(errors) print(errors)
if errors: if errors:
gr.Warning(errors) gr.Warning(errors)
return return usn, pwd, pwd_cnf
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.username_lower == usn.lower()) statement = select(User).where(User.username_lower == usn.lower())
@ -265,8 +300,22 @@ class UserManagement(BasePage):
session.commit() session.commit()
gr.Info(f'User "{usn}" created successfully') gr.Info(f'User "{usn}" created successfully')
def list_users(self): return "", "", ""
def list_users(self, user_id):
if user_id is None:
return [], pd.DataFrame.from_records(
[{"id": "-", "username": "-", "admin": "-"}]
)
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.id == user_id)
user = session.exec(statement).one()
if not user.admin:
return [], pd.DataFrame.from_records(
[{"id": "-", "username": "-", "admin": "-"}]
)
statement = select(User) statement = select(User)
results = [ results = [
{"id": user.id, "username": user.username, "admin": user.admin} {"id": user.id, "username": user.username, "admin": user.admin}
@ -284,18 +333,17 @@ class UserManagement(BasePage):
def select_user(self, user_list, ev: gr.SelectData): def select_user(self, user_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0: if ev.value == "-" and ev.index[0] == 0:
gr.Info("No user is loaded. Please refresh the user list") gr.Info("No user is loaded. Please refresh the user list")
return None, self.selected_panel_false return -1
if not ev.selected: if not ev.selected:
return None, self.selected_panel_false return -1
return user_list["id"][ev.index[0]], self.selected_panel_true.format( return user_list["id"][ev.index[0]]
name=user_list["username"][ev.index[0]]
)
def on_selected_user_change(self, selected_user_id): def on_selected_user_change(self, selected_user_id):
if selected_user_id is None: if selected_user_id == -1:
deselect_button = gr.update(visible=False) _selected_panel = gr.update(visible=False)
_selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True) btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False) btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False) btn_delete_no = gr.update(visible=False)
@ -304,7 +352,8 @@ class UserManagement(BasePage):
pwd_cnf_edit = gr.update(value="") pwd_cnf_edit = gr.update(value="")
admin_edit = gr.update(value=False) admin_edit = gr.update(value=False)
else: else:
deselect_button = gr.update(visible=True) _selected_panel = gr.update(visible=True)
_selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True) btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False) btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False) btn_delete_no = gr.update(visible=False)
@ -319,7 +368,8 @@ class UserManagement(BasePage):
admin_edit = gr.update(value=user.admin) admin_edit = gr.update(value=user.admin)
return ( return (
deselect_button, _selected_panel,
_selected_panel_btn,
btn_delete, btn_delete,
btn_delete_yes, btn_delete_yes,
btn_delete_no, btn_delete_no,
@ -344,17 +394,16 @@ class UserManagement(BasePage):
return btn_delete, btn_delete_yes, btn_delete_no return btn_delete, btn_delete_yes, btn_delete_no
def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin): def save_user(self, selected_user_id, usn, pwd, pwd_cnf, admin):
if usn:
errors = validate_username(usn) errors = validate_username(usn)
if errors: if errors:
gr.Warning(errors) gr.Warning(errors)
return return pwd, pwd_cnf
if pwd: if pwd:
errors = validate_password(pwd, pwd_cnf) errors = validate_password(pwd, pwd_cnf)
if errors: if errors:
gr.Warning(errors) gr.Warning(errors)
return return pwd, pwd_cnf
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id)) statement = select(User).where(User.id == int(selected_user_id))
@ -367,11 +416,17 @@ class UserManagement(BasePage):
session.commit() session.commit()
gr.Info(f'User "{usn}" updated successfully') gr.Info(f'User "{usn}" updated successfully')
def delete_user(self, selected_user_id): return "", ""
def delete_user(self, current_user, selected_user_id):
if current_user == selected_user_id:
gr.Warning("You cannot delete yourself")
return selected_user_id
with Session(engine) as session: with Session(engine) as session:
statement = select(User).where(User.id == int(selected_user_id)) statement = select(User).where(User.id == int(selected_user_id))
user = session.exec(statement).one() user = session.exec(statement).one()
session.delete(user) session.delete(user)
session.commit() session.commit()
gr.Info(f'User "{user.username}" deleted successfully') gr.Info(f'User "{user.username}" deleted successfully')
return None, self.selected_panel_false return -1

View File

@ -7,8 +7,10 @@ from ktem.app import BasePage
from ktem.components import reasonings from ktem.components import reasonings
from ktem.db.models import Conversation, engine from ktem.db.models import Conversation, engine
from sqlmodel import Session, select from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
from .chat_panel import ChatPanel from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE from .common import STATE
from .control import ConversationControl from .control import ConversationControl
from .report import ReportIssue from .report import ReportIssue
@ -26,24 +28,39 @@ class ChatPage(BasePage):
with gr.Column(scale=1): with gr.Column(scale=1):
self.chat_control = ConversationControl(self._app) self.chat_control = ConversationControl(self._app)
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion = ChatSuggestion(self._app)
for index in self._app.index_manager.indices: for index in self._app.index_manager.indices:
index.selector = -1 index.selector = None
index_ui = index.get_selector_component_ui() index_ui = index.get_selector_component_ui()
if not index_ui: if not index_ui:
# the index doesn't have a selector UI component
continue continue
index_ui.unrender() index_ui.unrender() # need to rerender later within Accordion
with gr.Accordion(label=f"{index.name} Index", open=False): with gr.Accordion(label=f"{index.name} Index", open=False):
index_ui.render() index_ui.render()
gr_index = index_ui.as_gradio_component() gr_index = index_ui.as_gradio_component()
if gr_index: if gr_index:
if isinstance(gr_index, list):
index.selector = tuple(
range(
len(self._indices_input),
len(self._indices_input) + len(gr_index),
)
)
self._indices_input.extend(gr_index)
else:
index.selector = len(self._indices_input) index.selector = len(self._indices_input)
self._indices_input.append(gr_index) self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui) setattr(self, f"_index_{index.id}", index_ui)
self.report_issue = ReportIssue(self._app) self.report_issue = ReportIssue(self._app)
with gr.Column(scale=6): with gr.Column(scale=6):
self.chat_panel = ChatPanel(self._app) self.chat_panel = ChatPanel(self._app)
with gr.Column(scale=3): with gr.Column(scale=3):
with gr.Accordion(label="Information panel", open=True): with gr.Accordion(label="Information panel", open=True):
self.info_panel = gr.HTML(elem_id="chat-info-panel") self.info_panel = gr.HTML(elem_id="chat-info-panel")
@ -54,11 +71,24 @@ class ChatPage(BasePage):
self.chat_panel.text_input.submit, self.chat_panel.text_input.submit,
self.chat_panel.submit_btn.click, self.chat_panel.submit_btn.click,
], ],
fn=self.chat_panel.submit_msg, fn=self.submit_msg,
inputs=[self.chat_panel.text_input, self.chat_panel.chatbot], inputs=[
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot], self.chat_panel.text_input,
self.chat_panel.chatbot,
self._app.user_id,
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
],
outputs=[
self.chat_panel.text_input,
self.chat_panel.chatbot,
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
],
concurrency_limit=20,
show_progress="hidden", show_progress="hidden",
).then( ).success(
fn=self.chat_fn, fn=self.chat_fn,
inputs=[ inputs=[
self.chat_control.conversation_id, self.chat_control.conversation_id,
@ -72,6 +102,7 @@ class ChatPage(BasePage):
self.info_panel, self.info_panel,
self.chat_state, self.chat_state,
], ],
concurrency_limit=20,
show_progress="minimal", show_progress="minimal",
).then( ).then(
fn=self.update_data_source, fn=self.update_data_source,
@ -82,6 +113,7 @@ class ChatPage(BasePage):
] ]
+ self._indices_input, + self._indices_input,
outputs=None, outputs=None,
concurrency_limit=20,
) )
self.chat_panel.regen_btn.click( self.chat_panel.regen_btn.click(
@ -98,6 +130,7 @@ class ChatPage(BasePage):
self.info_panel, self.info_panel,
self.chat_state, self.chat_state,
], ],
concurrency_limit=20,
show_progress="minimal", show_progress="minimal",
).then( ).then(
fn=self.update_data_source, fn=self.update_data_source,
@ -108,6 +141,7 @@ class ChatPage(BasePage):
] ]
+ self._indices_input, + self._indices_input,
outputs=None, outputs=None,
concurrency_limit=20,
) )
self.chat_panel.chatbot.like( self.chat_panel.chatbot.like(
@ -116,7 +150,12 @@ class ChatPage(BasePage):
outputs=None, outputs=None,
) )
self.chat_control.conversation.change( self.chat_control.btn_new.click(
self.chat_control.new_conv,
inputs=self._app.user_id,
outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
show_progress="hidden",
).then(
self.chat_control.select_conv, self.chat_control.select_conv,
inputs=[self.chat_control.conversation], inputs=[self.chat_control.conversation],
outputs=[ outputs=[
@ -124,12 +163,71 @@ class ChatPage(BasePage):
self.chat_control.conversation, self.chat_control.conversation,
self.chat_control.conversation_rn, self.chat_control.conversation_rn,
self.chat_panel.chatbot, self.chat_panel.chatbot,
self.info_panel,
self.chat_state, self.chat_state,
] ]
+ self._indices_input, + self._indices_input,
show_progress="hidden", show_progress="hidden",
) )
self.chat_control.btn_del.click(
lambda id: self.toggle_delete(id),
inputs=[self.chat_control.conversation_id],
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.btn_del_conf.click(
self.chat_control.delete_conv,
inputs=[self.chat_control.conversation_id, self._app.user_id],
outputs=[self.chat_control.conversation_id, self.chat_control.conversation],
show_progress="hidden",
).then(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
]
+ self._indices_input,
show_progress="hidden",
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.btn_del_cnl.click(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.chat_control.conversation_rn_btn.click(
self.chat_control.rename_conv,
inputs=[
self.chat_control.conversation_id,
self.chat_control.conversation_rn,
self._app.user_id,
],
outputs=[self.chat_control.conversation, self.chat_control.conversation],
show_progress="hidden",
)
self.chat_control.conversation.select(
self.chat_control.select_conv,
inputs=[self.chat_control.conversation],
outputs=[
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
]
+ self._indices_input,
show_progress="hidden",
).then(
lambda: self.toggle_delete(""),
outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
)
self.report_issue.report_btn.click( self.report_issue.report_btn.click(
self.report_issue.report, self.report_issue.report,
inputs=[ inputs=[
@ -140,11 +238,77 @@ class ChatPage(BasePage):
self.chat_panel.chatbot, self.chat_panel.chatbot,
self._app.settings_state, self._app.settings_state,
self._app.user_id, self._app.user_id,
self.info_panel,
self.chat_state, self.chat_state,
] ]
+ self._indices_input, + self._indices_input,
outputs=None, outputs=None,
) )
if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
self.chat_suggestion.example.select(
self.chat_suggestion.select_example,
outputs=[self.chat_panel.text_input],
show_progress="hidden",
)
def submit_msg(self, chat_input, chat_history, user_id, conv_id, conv_name):
"""Submit a message to the chatbot"""
if not chat_input:
raise ValueError("Input is empty")
if not conv_id:
id_, update = self.chat_control.new_conv(user_id)
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == id_)
name = session.exec(statement).one().name
new_conv_id = id_
conv_update = update
new_conv_name = name
else:
new_conv_id = conv_id
conv_update = gr.update()
new_conv_name = conv_name
return (
"",
chat_history + [(chat_input, None)],
new_conv_id,
conv_update,
new_conv_name,
)
def toggle_delete(self, conv_id):
if conv_id:
return gr.update(visible=False), gr.update(visible=True)
else:
return gr.update(visible=True), gr.update(visible=False)
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.chat_control.reload_conv,
"inputs": [self._app.user_id],
"outputs": [self.chat_control.conversation],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": lambda: self.chat_control.select_conv(""),
"outputs": [
self.chat_control.conversation_id,
self.chat_control.conversation,
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
]
+ self._indices_input,
"show_progress": "hidden",
},
)
def update_data_source(self, convo_id, messages, state, *selecteds): def update_data_source(self, convo_id, messages, state, *selecteds):
"""Update the data source""" """Update the data source"""
@ -154,8 +318,12 @@ class ChatPage(BasePage):
selecteds_ = {} selecteds_ = {}
for index in self._app.index_manager.indices: for index in self._app.index_manager.indices:
if index.selector != -1: if index.selector is None:
continue
if isinstance(index.selector, int):
selecteds_[str(index.id)] = selecteds[index.selector] selecteds_[str(index.id)] = selecteds[index.selector]
else:
selecteds_[str(index.id)] = [selecteds[i] for i in index.selector]
with Session(engine) as session: with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id) statement = select(Conversation).where(Conversation.id == convo_id)
@ -205,8 +373,11 @@ class ChatPage(BasePage):
retrievers = [] retrievers = []
for index in self._app.index_manager.indices: for index in self._app.index_manager.indices:
index_selected = [] index_selected = []
if index.selector != -1: if isinstance(index.selector, int):
index_selected = selecteds[index.selector] index_selected = selecteds[index.selector]
if isinstance(index.selector, tuple):
for i in index.selector:
index_selected.append(selecteds[i])
iretrievers = index.get_retriever_pipelines(settings, index_selected) iretrievers = index.get_retriever_pipelines(settings, index_selected)
retrievers += iretrievers retrievers += iretrievers
@ -250,6 +421,9 @@ class ChatPage(BasePage):
break break
if "output" in response: if "output" in response:
if response["output"] is None:
text = ""
else:
text += response["output"] text += response["output"]
if "evidence" in response: if "evidence" in response:

View File

@ -0,0 +1,26 @@
import gradio as gr
from ktem.app import BasePage
from theflow.settings import settings as flowsettings
class ChatSuggestion(BasePage):
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
chat_samples = getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION_SAMPLES", [])
chat_samples = [[each] for each in chat_samples]
with gr.Accordion(label="Chat Suggestion", open=False) as self.accordion:
self.example = gr.DataFrame(
value=chat_samples,
headers=["Sample"],
interactive=False,
wrap=True,
)
def as_gradio_component(self):
return self.example
def select_example(self, ev: gr.SelectData):
return ev.value

View File

@ -10,6 +10,17 @@ from .common import STATE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def is_conv_name_valid(name):
"""Check if the conversation name is valid"""
errors = []
if len(name) == 0:
errors.append("Name cannot be empty")
elif len(name) > 40:
errors.append("Name cannot be longer than 40 characters")
return "; ".join(errors)
class ConversationControl(BasePage): class ConversationControl(BasePage):
"""Manage conversation""" """Manage conversation"""
@ -28,9 +39,17 @@ class ConversationControl(BasePage):
interactive=True, interactive=True,
) )
with gr.Row(): with gr.Row() as self._new_delete:
self.conversation_new_btn = gr.Button(value="New", min_width=10) self.btn_new = gr.Button(value="New", min_width=10)
self.conversation_del_btn = gr.Button(value="Delete", min_width=10) self.btn_del = gr.Button(value="Delete", min_width=10)
with gr.Row(visible=False) as self._delete_confirm:
self.btn_del_conf = gr.Button(
value="Delete",
variant="primary",
min_width=10,
)
self.btn_del_cnl = gr.Button(value="Cancel", min_width=10)
with gr.Row(): with gr.Row():
self.conversation_rn = gr.Text( self.conversation_rn = gr.Text(
@ -52,48 +71,6 @@ class ConversationControl(BasePage):
# outputs=[current_state], # outputs=[current_state],
# ) # )
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(
name="onSignIn",
definition={
"fn": self.reload_conv,
"inputs": [self._app.user_id],
"outputs": [self.conversation],
"show_progress": "hidden",
},
)
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.reload_conv,
"inputs": [self._app.user_id],
"outputs": [self.conversation],
"show_progress": "hidden",
},
)
def on_register_events(self):
self.conversation_new_btn.click(
self.new_conv,
inputs=self._app.user_id,
outputs=[self.conversation_id, self.conversation],
show_progress="hidden",
)
self.conversation_del_btn.click(
self.delete_conv,
inputs=[self.conversation_id, self._app.user_id],
outputs=[self.conversation_id, self.conversation],
show_progress="hidden",
)
self.conversation_rn_btn.click(
self.rename_conv,
inputs=[self.conversation_id, self.conversation_rn, self._app.user_id],
outputs=[self.conversation, self.conversation],
show_progress="hidden",
)
def load_chat_history(self, user_id): def load_chat_history(self, user_id):
"""Reload chat history""" """Reload chat history"""
options = [] options = []
@ -112,7 +89,7 @@ class ConversationControl(BasePage):
def reload_conv(self, user_id): def reload_conv(self, user_id):
conv_list = self.load_chat_history(user_id) conv_list = self.load_chat_history(user_id)
if conv_list: if conv_list:
return gr.update(value=conv_list[0][1], choices=conv_list) return gr.update(value=None, choices=conv_list)
else: else:
return gr.update(value=None, choices=[]) return gr.update(value=None, choices=[])
@ -133,10 +110,15 @@ class ConversationControl(BasePage):
return id_, gr.update(value=id_, choices=history) return id_, gr.update(value=id_, choices=history)
def delete_conv(self, conversation_id, user_id): def delete_conv(self, conversation_id, user_id):
"""Create new chat""" """Delete the selected conversation"""
if not conversation_id:
gr.Warning("No conversation selected.")
return None, gr.update()
if user_id is None: if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)") gr.Warning("Please sign in first (Settings → User Settings)")
return None, gr.update() return None, gr.update()
with Session(engine) as session: with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id) statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one() result = session.exec(statement).one()
@ -161,6 +143,7 @@ class ConversationControl(BasePage):
name = result.name name = result.name
selected = result.data_source.get("selected", {}) selected = result.data_source.get("selected", {})
chats = result.data_source.get("messages", []) chats = result.data_source.get("messages", [])
info_panel = ""
state = result.data_source.get("state", STATE) state = result.data_source.get("state", STATE)
except Exception as e: except Exception as e:
logger.warning(e) logger.warning(e)
@ -168,22 +151,36 @@ class ConversationControl(BasePage):
name = "" name = ""
selected = {} selected = {}
chats = [] chats = []
info_panel = ""
state = STATE state = STATE
indices = [] indices = []
for index in self._app.index_manager.indices: for index in self._app.index_manager.indices:
# assume that the index has selector # assume that the index has selector
if index.selector == -1: if index.selector is None:
continue continue
if isinstance(index.selector, int):
indices.append(selected.get(str(index.id), [])) indices.append(selected.get(str(index.id), []))
if isinstance(index.selector, tuple):
indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
return id_, id_, name, chats, state, *indices return id_, id_, name, chats, info_panel, state, *indices
def rename_conv(self, conversation_id, new_name, user_id): def rename_conv(self, conversation_id, new_name, user_id):
"""Rename the conversation""" """Rename the conversation"""
if user_id is None: if user_id is None:
gr.Warning("Please sign in first (Settings → User Settings)") gr.Warning("Please sign in first (Settings → User Settings)")
return gr.update(), "" return gr.update(), ""
if not conversation_id:
gr.Warning("No conversation selected.")
return gr.update(), ""
errors = is_conv_name_valid(new_name)
if errors:
gr.Warning(errors)
return gr.update(), conversation_id
with Session(engine) as session: with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == conversation_id) statement = select(Conversation).where(Conversation.id == conversation_id)
result = session.exec(statement).one() result = session.exec(statement).one()

View File

@ -48,13 +48,19 @@ class ReportIssue(BasePage):
chat_history: list, chat_history: list,
settings: dict, settings: dict,
user_id: Optional[int], user_id: Optional[int],
info_panel: str,
chat_state: dict, chat_state: dict,
*selecteds *selecteds,
): ):
selecteds_ = {} selecteds_ = {}
for index in self._app.index_manager.indices: for index in self._app.index_manager.indices:
if index.selector != -1: if index.selector is not None:
if isinstance(index.selector, int):
selecteds_[str(index.id)] = selecteds[index.selector] selecteds_[str(index.id)] = selecteds[index.selector]
elif isinstance(index.selector, tuple):
selecteds_[str(index.id)] = [selecteds[_] for _ in index.selector]
else:
print(f"Unknown selector type: {index.selector}")
with Session(engine) as session: with Session(engine) as session:
issue = IssueReport( issue = IssueReport(
@ -66,6 +72,7 @@ class ReportIssue(BasePage):
chat={ chat={
"conv_id": conv_id, "conv_id": conv_id,
"chat_history": chat_history, "chat_history": chat_history,
"info_panel": info_panel,
"chat_state": chat_state, "chat_state": chat_state,
"selecteds": selecteds_, "selecteds": selecteds_,
}, },

View File

@ -31,11 +31,10 @@ class LoginPage(BasePage):
self.on_building_ui() self.on_building_ui()
def on_building_ui(self): def on_building_ui(self):
gr.Markdown("Welcome to Kotaemon") gr.Markdown("# Welcome to Kotaemon")
self.usn = gr.Textbox(label="Username") self.usn = gr.Textbox(label="Username", visible=False)
self.pwd = gr.Textbox(label="Password", type="password") self.pwd = gr.Textbox(label="Password", type="password", visible=False)
self.btn_login = gr.Button("Login") self.btn_login = gr.Button("Login", visible=False)
self._dummy = gr.State()
def on_register_events(self): def on_register_events(self):
onSignIn = gr.on( onSignIn = gr.on(
@ -45,24 +44,56 @@ class LoginPage(BasePage):
outputs=[self._app.user_id, self.usn, self.pwd], outputs=[self._app.user_id, self.usn, self.pwd],
show_progress="hidden", show_progress="hidden",
js=signin_js, js=signin_js,
).then(
self.toggle_login_visibility,
inputs=[self._app.user_id],
outputs=[self.usn, self.pwd, self.btn_login],
) )
for event in self._app.get_event("onSignIn"): for event in self._app.get_event("onSignIn"):
onSignIn = onSignIn.success(**event) onSignIn = onSignIn.success(**event)
def toggle_login_visibility(self, user_id):
return (
gr.update(visible=user_id is None),
gr.update(visible=user_id is None),
gr.update(visible=user_id is None),
)
def _on_app_created(self): def _on_app_created(self):
self._app.app.load( onSignIn = self._app.app.load(
None, self.login,
inputs=None, inputs=[self.usn, self.pwd],
outputs=[self.usn, self.pwd], outputs=[self._app.user_id, self.usn, self.pwd],
show_progress="hidden",
js=fetch_creds, js=fetch_creds,
).then(
self.toggle_login_visibility,
inputs=[self._app.user_id],
outputs=[self.usn, self.pwd, self.btn_login],
)
for event in self._app.get_event("onSignIn"):
onSignIn = onSignIn.success(**event)
def on_subscribe_public_events(self):
self._app.subscribe_event(
name="onSignOut",
definition={
"fn": self.toggle_login_visibility,
"inputs": [self._app.user_id],
"outputs": [self.usn, self.pwd, self.btn_login],
"show_progress": "hidden",
},
) )
def login(self, usn, pwd): def login(self, usn, pwd):
if not usn or not pwd:
return None, usn, pwd
hashed_password = hashlib.sha256(pwd.encode()).hexdigest() hashed_password = hashlib.sha256(pwd.encode()).hexdigest()
with Session(engine) as session: with Session(engine) as session:
stmt = select(User).where( stmt = select(User).where(
User.username_lower == usn.lower(), User.password == hashed_password User.username_lower == usn.lower().strip(),
User.password == hashed_password,
) )
result = session.exec(stmt).all() result = session.exec(stmt).all()
if result: if result:

View File

@ -164,9 +164,14 @@ class SettingsPage(BasePage):
show_progress="hidden", show_progress="hidden",
) )
onSignOutClick = self.signout.click( onSignOutClick = self.signout.click(
lambda: (None, "Current user: ___"), lambda: (None, "Current user: ___", "", ""),
inputs=None, inputs=None,
outputs=[self._user_id, self.current_name], outputs=[
self._user_id,
self.current_name,
self.password_change,
self.password_change_confirm,
],
show_progress="hidden", show_progress="hidden",
js=signout_js, js=signout_js,
).then( ).then(
@ -192,8 +197,12 @@ class SettingsPage(BasePage):
self.password_change_btn = gr.Button("Change password", interactive=True) self.password_change_btn = gr.Button("Change password", interactive=True)
def change_password(self, user_id, password, password_confirm): def change_password(self, user_id, password, password_confirm):
if password != password_confirm: from ktem.pages.admin.user import validate_password
gr.Warning("Password does not match")
errors = validate_password(password, password_confirm)
if errors:
print(errors)
gr.Warning(errors)
return password, password_confirm return password, password_confirm
with Session(engine) as session: with Session(engine) as session:

View File

@ -34,12 +34,16 @@ class BaseReasoning(BaseComponent):
@classmethod @classmethod
def get_pipeline( def get_pipeline(
cls, user_settings: dict, retrievers: Optional[list["BaseComponent"]] = None cls,
user_settings: dict,
state: dict,
retrievers: Optional[list["BaseComponent"]] = None,
) -> "BaseReasoning": ) -> "BaseReasoning":
"""Get the reasoning pipeline for the app to execute """Get the reasoning pipeline for the app to execute
Args: Args:
user_setting: user settings user_setting: user settings
state: conversation state
retrievers (list): List of retrievers retrievers (list): List of retrievers
""" """
return cls() return cls()

View File

@ -22,6 +22,8 @@ from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import ChatLLM, PromptTemplate from kotaemon.llms import ChatLLM, PromptTemplate
from kotaemon.loaders.utils.gpt4v import stream_gpt4v from kotaemon.loaders.utils.gpt4v import stream_gpt4v
from .base import BaseReasoning
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EVIDENCE_MODE_TEXT = 0 EVIDENCE_MODE_TEXT = 0
@ -204,7 +206,7 @@ class AnswerWithContextPipeline(BaseComponent):
lang: str = "English" # support English and Japanese lang: str = "English" # support English and Japanese
async def run( # type: ignore async def run( # type: ignore
self, question: str, evidence: str, evidence_mode: int = 0 self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Document: ) -> Document:
"""Answer the question based on the evidence """Answer the question based on the evidence
@ -336,7 +338,7 @@ class RewriteQuestionPipeline(BaseComponent):
return Document(text=output) return Document(text=output)
class FullQAPipeline(BaseComponent): class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer""" """Question answering pipeline. Handle from question to answer"""
class Config: class Config:
@ -352,6 +354,8 @@ class FullQAPipeline(BaseComponent):
async def run( # type: ignore async def run( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore ) -> Document: # type: ignore
import markdown
docs = [] docs = []
doc_ids = [] doc_ids = []
if self.use_rewrite: if self.use_rewrite:
@ -364,12 +368,16 @@ class FullQAPipeline(BaseComponent):
docs.append(doc) docs.append(doc)
doc_ids.append(doc.doc_id) doc_ids.append(doc.doc_id)
for doc in docs: for doc in docs:
# TODO: a better approach to show the information
text = markdown.markdown(
doc.text, extensions=["markdown.extensions.tables"]
)
self.report_output( self.report_output(
{ {
"evidence": ( "evidence": (
"<details open>" "<details open>"
f"<summary>{doc.metadata['file_name']}</summary>" f"<summary>{doc.metadata['file_name']}</summary>"
f"{doc.text}" f"{text}"
"</details><br>" "</details><br>"
) )
} }
@ -378,7 +386,12 @@ class FullQAPipeline(BaseComponent):
evidence_mode, evidence = self.evidence_pipeline(docs).content evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = await self.answering_pipeline( answer = await self.answering_pipeline(
question=message, evidence=evidence, evidence_mode=evidence_mode question=message,
history=history,
evidence=evidence,
evidence_mode=evidence_mode,
conv_id=conv_id,
**kwargs,
) )
# prepare citation # prepare citation
@ -388,13 +401,28 @@ class FullQAPipeline(BaseComponent):
for quote in fact_with_evidence.substring_quote: for quote in fact_with_evidence.substring_quote:
for doc in docs: for doc in docs:
start_idx = doc.text.find(quote) start_idx = doc.text.find(quote)
if start_idx >= 0: if start_idx == -1:
continue
end_idx = start_idx + len(quote)
current_idx = start_idx
if "|" not in doc.text[start_idx:end_idx]:
spans[doc.doc_id].append(
{"start": start_idx, "end": end_idx}
)
else:
while doc.text[current_idx:end_idx].find("|") != -1:
match_idx = doc.text[current_idx:end_idx].find("|")
spans[doc.doc_id].append( spans[doc.doc_id].append(
{ {
"start": start_idx, "start": current_idx,
"end": start_idx + len(quote), "end": current_idx + match_idx,
} }
) )
current_idx += match_idx + 2
if current_idx > end_idx:
break
break break
id2docs = {doc.doc_id: doc for doc in docs} id2docs = {doc.doc_id: doc for doc in docs}
@ -414,12 +442,15 @@ class FullQAPipeline(BaseComponent):
if idx < len(ss) - 1: if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :] text += id2docs[id].text[ss[-1]["end"] :]
text_out = markdown.markdown(
text, extensions=["markdown.extensions.tables"]
)
self.report_output( self.report_output(
{ {
"evidence": ( "evidence": (
"<details open>" "<details open>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>" f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{text}" f"{text_out}"
"</details><br>" "</details><br>"
) )
} }
@ -434,12 +465,15 @@ class FullQAPipeline(BaseComponent):
{"evidence": "Retrieved segments without matching evidence:\n"} {"evidence": "Retrieved segments without matching evidence:\n"}
) )
for id in list(not_detected): for id in list(not_detected):
text_out = markdown.markdown(
id2docs[id].text, extensions=["markdown.extensions.tables"]
)
self.report_output( self.report_output(
{ {
"evidence": ( "evidence": (
"<details>" "<details>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>" f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{id2docs[id].text}" f"{text_out}"
"</details><br>" "</details><br>"
) )
} }