Restructure index to allow it to be dynamically created by end-user (#151)

1. Introduce the concept of "collection_name" to docstore and vector store. Each collection can be viewed similarly to a table in a SQL database. It allows better organizing information within this data source.
2. Move the `Index` and `Source` tables from the application scope into the index scope. For each new index created by user, these tables should increase accordingly. So it depends on the index, rather than the app.
3. Make each index responsible for the UI components in the app.
4. Construct the File UI page.
This commit is contained in:
Duc Nguyen (john) 2024-03-07 01:50:47 +07:00 committed by GitHub
parent cc87aaa783
commit 8a90fcfc99
43 changed files with 1658 additions and 812 deletions

View File

@ -0,0 +1,92 @@
The file index stores files in a local folder and index them for retrieval.
This file index provides the following infrastructure to support the indexing:
- SQL table Source: store the list of files that are indexed by the system
- Vector store: contain the embedding of segments of the files
- Document store: contain the text of segments of the files. Each text stored
in this document store is associated with a vector in the vector store.
- SQL table Index: store the relationship between (1) the source and the
docstore, and (2) the source and the vector store.
The indexing and retrieval pipelines are encouraged to use the above software
infrastructure.
## Indexing pipeline
The ktem has default indexing pipeline: `ktem.index.file.pipelines.IndexDocumentPipeline`.
This default pipeline works as follow:
- Input: list of file paths
- Output: list of nodes that are indexed into database
- Process:
- Read files into texts. Different file types has different ways to read
texts.
- Split text files into smaller segments
- Run each segments into embeddings.
- Store the embeddings into vector store. Store the texts of each segment
into docstore. Store the list of files in Source. Store the linking
between Sources and docstore + vectorstore in Index table.
You can customize this default pipeline if your indexing process is close to the
default pipeline. You can create your own indexing pipeline if there are too
much different logic.
### Customize the default pipeline
The default pipeline provides the contact points in `flowsettings.py`.
1. `FILE_INDEX_PIPELINE_FILE_EXTRACTORS`. Supply overriding file extractor,
based on file extension. Example: `{".pdf": "path.to.PDFReader", ".xlsx": "path.to.ExcelReader"}`
2. `FILE_INDEX_PIPELINE_SPLITTER_CHUNK_SIZE`. The expected number of characters
of each text segment. Example: 1024.
3. `FILE_INDEX_PIPELINE_SPLITTER_CHUNK_OVERLAP`. The expected number of
characters that consecutive text segments should overlap with each other.
Example: 256.
### Create your own indexing pipeline
Your indexing pipeline will subclass `BaseFileIndexIndexing`.
You should define the following methods:
- `run(self, file_paths)`: run the indexing given the pipeline
- `get_pipeline(cls, user_settings, index_settings)`: return the
fully-initialized pipeline, ready to be used by ktem.
- `user_settings`: is a dictionary contains user settings (e.g. `{"pdf_mode": True, "num_retrieval": 5}`). You can declare these settings in the `get_user_settings` classmethod. ktem will collect these settings into the app Settings page, and will supply these user settings to your `get_pipeline` method.
- `index_settings`: is a dictionary. Currently it's empty for File Index.
- `get_user_settings`: to declare user settings, eturn a dictionary.
By subclassing `BaseFileIndexIndexing`, You will have access to the following resources:
- `self._Source`: the source table
- `self._Index`: the index table
- `self._VS`: the vector store
- `self._DS`: the docstore
Once you have prepared your pipeline, register it in `flowsettings.py`: `FILE_INDEX_PIPELINE = "<python.path.to.your.pipeline>"`.
## Retrieval pipeline
The ktem has default retrieval pipeline:
`ktem.index.file.pipelines.DocumentRetrievalPipeline`. This pipeline works as
follow:
- Input: user text query & optionally a list of source file ids
- Output: the output segments that match the user text query
- Process:
- If a list of source file ids is given, get the list of vector ids that
associate with those file ids.
- Embed the user text query.
- Query the vector store. Provide a list of vector ids to limit query scope
if the user restrict.
- Return the matched text segments
## Software infrastructure
| Infra | Access | Schema | Ref |
| ---------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------- |
| SQL table Source | self.\_Source | - id (int): id of the source (auto)<br>- name (str): the name of the file<br>- path (str): the path of the file<br>- size (int): the file size in bytes<br>- text_length (int): the number of characters in the file (default 0)<br>- date_created (datetime): the time the file is created (auto) | This is SQLALchemy ORM class. Can consult |
| SQL table Index | self.\_Index | - id (int): id of the index entry (auto)<br>- source_id (int): the id of a file in the Source table<br>- target_id: the id of the segment in docstore or vector store<br>- relation_type (str): if the link is "document" or "vector" | This is SQLAlchemy ORM class |
| Vector store | self.\_VS | - self.\_VS.add: add the list of embeddings to the vector store (optionally associate metadata and ids)<br>- self.\_VS.delete: delete vector entries based on ids<br>- self.\_VS.query: get embeddings based on embeddings. | kotaemon > storages > vectorstores > BaseVectorStore |
| Doc store | self.\_DS | - self.\_DS.add: add the segments to document stores<br>- self.\_DS.get: get the segments based on id<br>- self.\_DS.get_all: get all segments<br>- self.\_DS.delete: delete segments based on id | kotaemon > storages > docstores > base > BaseDocumentStore |

View File

@ -45,3 +45,8 @@ class BaseDocumentStore(ABC):
def delete(self, ids: Union[List[str], str]): def delete(self, ids: Union[List[str], str]):
"""Delete document by id""" """Delete document by id"""
... ...
@abstractmethod
def drop(self):
"""Drop the document store"""
...

View File

@ -12,7 +12,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
def __init__( def __init__(
self, self,
index_name: str = "docstore", collection_name: str = "docstore",
elasticsearch_url: str = "http://localhost:9200", elasticsearch_url: str = "http://localhost:9200",
k1: float = 2.0, k1: float = 2.0,
b: float = 0.75, b: float = 0.75,
@ -27,7 +27,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
) )
self.elasticsearch_url = elasticsearch_url self.elasticsearch_url = elasticsearch_url
self.index_name = index_name self.index_name = collection_name
self.k1 = k1 self.k1 = k1
self.b = b self.b = b
@ -55,9 +55,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
} }
# Create the index with the specified settings and mappings # Create the index with the specified settings and mappings
if not self.client.indices.exists(index=index_name): if not self.client.indices.exists(index=self.index_name):
self.client.indices.create( self.client.indices.create(
index=index_name, mappings=mappings, settings=settings index=self.index_name, mappings=mappings, settings=settings
) )
def add( def add(
@ -164,6 +164,11 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
self.client.delete_by_query(index=self.index_name, body=query) self.client.delete_by_query(index=self.index_name, body=query)
self.client.indices.refresh(index=self.index_name) self.client.indices.refresh(index=self.index_name)
def drop(self):
"""Drop the document store"""
self.client.indices.delete(index=self.index_name)
self.client.indices.refresh(index=self.index_name)
def __persist_flow__(self): def __persist_flow__(self):
return { return {
"index_name": self.index_name, "index_name": self.index_name,

View File

@ -83,3 +83,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
def __persist_flow__(self): def __persist_flow__(self):
return {} return {}
def drop(self):
"""Drop the document store"""
self._store = {}

View File

@ -9,11 +9,15 @@ from .in_memory import InMemoryDocumentStore
class SimpleFileDocumentStore(InMemoryDocumentStore): class SimpleFileDocumentStore(InMemoryDocumentStore):
"""Improve InMemoryDocumentStore by auto saving whenever the corpus is changed""" """Improve InMemoryDocumentStore by auto saving whenever the corpus is changed"""
def __init__(self, path: str | Path): def __init__(self, path: str | Path, collection_name: str = "default"):
super().__init__() super().__init__()
self._path = path self._path = path
if path is not None and Path(path).is_file(): self._collection_name = collection_name
self.load(path)
Path(path).mkdir(parents=True, exist_ok=True)
self._save_path = Path(path) / f"{collection_name}.json"
if self._save_path.is_file():
self.load(self._save_path)
def get(self, ids: Union[List[str], str]) -> List[Document]: def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id""" """Get document by id"""
@ -22,7 +26,7 @@ class SimpleFileDocumentStore(InMemoryDocumentStore):
for doc_id in ids: for doc_id in ids:
if doc_id not in self._store: if doc_id not in self._store:
self.load(self._path) self.load(self._save_path)
break break
return [self._store[doc_id] for doc_id in ids] return [self._store[doc_id] for doc_id in ids]
@ -43,14 +47,22 @@ class SimpleFileDocumentStore(InMemoryDocumentStore):
found in the docstore (default to False) found in the docstore (default to False)
""" """
super().add(docs=docs, ids=ids, **kwargs) super().add(docs=docs, ids=ids, **kwargs)
self.save(self._path) self.save(self._save_path)
def delete(self, ids: Union[List[str], str]): def delete(self, ids: Union[List[str], str]):
"""Delete document by id""" """Delete document by id"""
super().delete(ids=ids) super().delete(ids=ids)
self.save(self._path) self.save(self._save_path)
def drop(self):
"""Drop the document store"""
super().drop()
self._save_path.unlink(missing_ok=True)
def __persist_flow__(self): def __persist_flow__(self):
from theflow.utils.modules import serialize from theflow.utils.modules import serialize
return {"path": serialize(self._path)} return {
"path": serialize(self._path),
"collection_name": self._collection_name,
}

View File

@ -66,6 +66,11 @@ class BaseVectorStore(ABC):
""" """
... ...
@abstractmethod
def drop(self):
"""Drop the vector store"""
...
class LlamaIndexVectorStore(BaseVectorStore): class LlamaIndexVectorStore(BaseVectorStore):
_li_class: type[LIVectorStore | BasePydanticVectorStore] _li_class: type[LIVectorStore | BasePydanticVectorStore]

View File

@ -66,17 +66,9 @@ class ChromaVectorStore(LlamaIndexVectorStore):
""" """
self._client.client.delete(ids=ids) self._client.client.delete(ids=ids)
def delete_collection(self, collection_name: Optional[str] = None): def drop(self):
"""Delete entire collection under specified name from vector stores """Delete entire collection from vector stores"""
self._client.client._client.delete_collection(self._client.client.name)
Args:
collection_name: Name of the collection to delete
"""
# a rather ugly chain call but it do the job of finding
# original chromadb client and call delete_collection() method
if collection_name is None:
collection_name = self._client.client.name
self._client.client._client.delete_collection(collection_name)
def count(self) -> int: def count(self) -> int:
return self._collection.count() return self._collection.count()

View File

@ -53,6 +53,10 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
""" """
self._client = self._client.from_persist_path(persist_path=load_path, fs=fs) self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)
def drop(self):
"""Clear the old data"""
self._data = SimpleVectorStoreData()
def __persist_flow__(self): def __persist_flow__(self):
d = self._data.to_dict() d = self._data.to_dict()
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}" d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"

View File

@ -20,6 +20,7 @@ class SimpleFileVectorStore(LlamaIndexVectorStore):
def __init__( def __init__(
self, self,
path: str | Path, path: str | Path,
collection_name: str = "default",
data: Optional[SimpleVectorStoreData] = None, data: Optional[SimpleVectorStoreData] = None,
fs: Optional[fsspec.AbstractFileSystem] = None, fs: Optional[fsspec.AbstractFileSystem] = None,
**kwargs: Any, **kwargs: Any,
@ -27,8 +28,9 @@ class SimpleFileVectorStore(LlamaIndexVectorStore):
"""Initialize params.""" """Initialize params."""
self._data = data or SimpleVectorStoreData() self._data = data or SimpleVectorStoreData()
self._fs = fs or fsspec.filesystem("file") self._fs = fs or fsspec.filesystem("file")
self._collection_name = collection_name
self._path = path self._path = path
self._save_path = Path(path) self._save_path = Path(path) / collection_name
super().__init__( super().__init__(
data=data, data=data,
@ -56,11 +58,16 @@ class SimpleFileVectorStore(LlamaIndexVectorStore):
self._client.persist(str(self._save_path), self._fs) self._client.persist(str(self._save_path), self._fs)
return r return r
def drop(self):
self._data = SimpleVectorStoreData()
self._save_path.unlink(missing_ok=True)
def __persist_flow__(self): def __persist_flow__(self):
d = self._data.to_dict() d = self._data.to_dict()
d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}" d["__type__"] = f"{self._data.__module__}.{self._data.__class__.__qualname__}"
return { return {
"data": d, "data": d,
"collection_name": self._collection_name,
"path": str(self._path), "path": str(self._path),
# "fs": self._fs, # "fs": self._fs,
} }

View File

@ -271,9 +271,7 @@ def test_inmemory_document_store_base_interfaces(tmp_path):
def test_simplefile_document_store_base_interfaces(tmp_path): def test_simplefile_document_store_base_interfaces(tmp_path):
"""Test all interfaces of a a document store""" """Test all interfaces of a a document store"""
path = tmp_path / "store.json" store = SimpleFileDocumentStore(path=tmp_path)
store = SimpleFileDocumentStore(path=path)
docs = [ docs = [
Document(text=f"Sample text {idx}", meta={"meta_key": f"meta_value_{idx}"}) Document(text=f"Sample text {idx}", meta={"meta_key": f"meta_value_{idx}"})
for idx in range(10) for idx in range(10)
@ -315,13 +313,13 @@ def test_simplefile_document_store_base_interfaces(tmp_path):
assert len(store.get_all()) == 17, "Document store should have 17 documents" assert len(store.get_all()) == 17, "Document store should have 17 documents"
# Test save # Test save
assert path.exists(), "File should exist" assert (tmp_path / "default.json").exists(), "File should exist"
# Test load # Test load
store2 = SimpleFileDocumentStore(path=path) store2 = SimpleFileDocumentStore(path=tmp_path)
assert len(store2.get_all()) == 17, "Laded document store should have 17 documents" assert len(store2.get_all()) == 17, "Laded document store should have 17 documents"
os.remove(path) os.remove(tmp_path / "default.json")
@patch( @patch(
@ -329,7 +327,7 @@ def test_simplefile_document_store_base_interfaces(tmp_path):
side_effect=_elastic_search_responses, side_effect=_elastic_search_responses,
) )
def test_elastic_document_store(elastic_api): def test_elastic_document_store(elastic_api):
store = ElasticsearchDocumentStore(index_name="test") store = ElasticsearchDocumentStore(collection_name="test")
docs = [ docs = [
Document(text=f"Sample text {idx}", meta={"meta_key": f"meta_value_{idx}"}) Document(text=f"Sample text {idx}", meta={"meta_key": f"meta_value_{idx}"})

View File

@ -81,7 +81,7 @@ class TestChromaVectorStore:
), "load function does not load data completely" ), "load function does not load data completely"
# test delete collection function # test delete collection function
db2.delete_collection() db2.drop()
# reinit the chroma with the same collection name # reinit the chroma with the same collection name
db2 = ChromaVectorStore(path=str(tmp_path)) db2 = ChromaVectorStore(path=str(tmp_path))
assert ( assert (
@ -133,10 +133,11 @@ class TestSimpleFileVectorStore:
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}] metadatas = [{"a": 1, "b": 2}, {"a": 3, "b": 4}, {"a": 5, "b": 6}]
ids = ["1", "2", "3"] ids = ["1", "2", "3"]
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json") collection_name = "test_save_load_delete"
db = SimpleFileVectorStore(path=tmp_path, collection_name=collection_name)
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
db.delete(["3"]) db.delete(["3"])
with open(tmp_path / "test_save_load_delete.json") as f: with open(tmp_path / collection_name) as f:
data = json.load(f) data = json.load(f)
assert ( assert (
"1" and "2" in data["text_id_to_ref_doc_id"] "1" and "2" in data["text_id_to_ref_doc_id"]
@ -144,11 +145,11 @@ class TestSimpleFileVectorStore:
assert ( assert (
"3" not in data["text_id_to_ref_doc_id"] "3" not in data["text_id_to_ref_doc_id"]
), "delete function does not delete data completely" ), "delete function does not delete data completely"
db2 = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json") db2 = SimpleFileVectorStore(path=tmp_path, collection_name=collection_name)
assert db2.get("2") == [ assert db2.get("2") == [
0.4, 0.4,
0.5, 0.5,
0.6, 0.6,
], "load function does not load data completely" ], "load function does not load data completely"
os.remove(tmp_path / "test_save_load_delete.json") os.remove(tmp_path / collection_name)

View File

@ -98,4 +98,12 @@ SETTINGS_REASONING = {
} }
KH_INDEX = "ktem.indexing.file.IndexDocumentPipeline" KH_INDEX_TYPES = ["ktem.index.file.FileIndex"]
KH_INDICES = [
{
"id": 1,
"name": "File",
"config": {},
"index_type": "ktem.index.file.FileIndex",
}
]

View File

@ -1,16 +1,13 @@
from pathlib import Path from pathlib import Path
from typing import Optional
import gradio as gr import gradio as gr
import pluggy import pluggy
from ktem import extension_protocol from ktem import extension_protocol
from ktem.components import reasonings from ktem.components import reasonings
from ktem.exceptions import HookAlreadyDeclared, HookNotDeclared from ktem.exceptions import HookAlreadyDeclared, HookNotDeclared
from ktem.settings import ( from ktem.index import IndexManager
BaseSettingGroup, from ktem.settings import BaseSettingGroup, SettingGroup, SettingReasoningGroup
SettingGroup,
SettingItem,
SettingReasoningGroup,
)
from theflow.settings import settings from theflow.settings import settings
from theflow.utils.modules import import_dotted_string from theflow.utils.modules import import_dotted_string
@ -55,22 +52,26 @@ class BaseApp:
self._callbacks: dict[str, list] = {} self._callbacks: dict[str, list] = {}
self._events: dict[str, list] = {} self._events: dict[str, list] = {}
self.register_indices()
self.register_reasonings()
self.register_extensions() self.register_extensions()
self.register_reasonings()
self.initialize_indices()
self.default_settings.reasoning.finalize() self.default_settings.reasoning.finalize()
self.default_settings.index.finalize() self.default_settings.index.finalize()
self.settings_state = gr.State(self.default_settings.flatten()) self.settings_state = gr.State(self.default_settings.flatten())
self.user_id = gr.State(1 if self.dev_mode else None) self.user_id = gr.State(1 if self.dev_mode else None)
def register_indices(self): def initialize_indices(self):
"""Register the index components from app settings""" """Create the index manager, start indices, and register to app settings"""
index = import_dotted_string(settings.KH_INDEX, safe=False) self.index_manager = IndexManager(self)
user_settings = index().get_user_settings() self.index_manager.on_application_startup()
for key, value in user_settings.items():
self.default_settings.index.settings[key] = SettingItem(**value) for index in self.index_manager.indices:
options = index.get_user_settings()
self.default_settings.index.options[index.id] = BaseSettingGroup(
settings=options
)
def register_reasonings(self): def register_reasonings(self):
"""Register the reasoning components from app settings""" """Register the reasoning components from app settings"""
@ -197,6 +198,27 @@ class BasePage:
def _on_app_created(self): def _on_app_created(self):
"""Called when the app is created""" """Called when the app is created"""
def as_gradio_component(self) -> Optional[gr.components.Component]:
"""Return the gradio components responsible for events
Note: in ideal scenario, this method shouldn't be necessary.
"""
return None
def render(self):
for value in self.__dict__.values():
if isinstance(value, gr.blocks.Block):
value.render()
if isinstance(value, BasePage):
value.render()
def unrender(self):
for value in self.__dict__.values():
if isinstance(value, gr.blocks.Block):
value.unrender()
if isinstance(value, BasePage):
value.unrender()
def declare_public_events(self): def declare_public_events(self):
"""Declare an event for the app""" """Declare an event for the app"""
for event in self.public_events: for event in self.public_events:

View File

@ -24,6 +24,10 @@ footer {
border-radius: 0; border-radius: 0;
} }
.indices-tab {
border: none !important;
}
#chat-tab, #settings-tab, #help-tab { #chat-tab, #settings-tab, #help-tab {
border: none !important; border: none !important;
} }

View File

@ -17,13 +17,21 @@ filestorage_path.mkdir(parents=True, exist_ok=True)
@cache @cache
def get_docstore() -> BaseDocumentStore: def get_docstore(collection_name: str = "default") -> BaseDocumentStore:
return deserialize(settings.KH_DOCSTORE, safe=False) from copy import deepcopy
ds_conf = deepcopy(settings.KH_DOCSTORE)
ds_conf["collection_name"] = collection_name
return deserialize(ds_conf, safe=False)
@cache @cache
def get_vectorstore() -> BaseVectorStore: def get_vectorstore(collection_name: str = "default") -> BaseVectorStore:
return deserialize(settings.KH_VECTORSTORE, safe=False) from copy import deepcopy
vs_conf = deepcopy(settings.KH_VECTORSTORE)
vs_conf["collection_name"] = collection_name
return deserialize(vs_conf, safe=False)
class ModelPool: class ModelPool:

View File

@ -1,65 +1,11 @@
import datetime import datetime
import uuid import uuid
from enum import Enum
from typing import Optional from typing import Optional
from sqlalchemy import JSON, Column from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel from sqlmodel import Field, SQLModel
class BaseSource(SQLModel):
"""The source of the document
Attributes:
id: canonical id to identify the source
name: human-friendly name of the source
path: path to retrieve the source
type: [TODO] to differentiate different types of sources (as each type can be
handled differently)
"""
__table_args__ = {"extend_existing": True}
id: str = Field(
default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
)
name: str
path: str
class SourceTargetRelation(str, Enum):
"""The type of relationship between the source and the target, to be used with the
Index table.
Current supported relations:
- document: the target is a document
- vector: the target is a vector
"""
DOCUMENT = "document"
VECTOR = "vector"
class BaseIndex(SQLModel):
"""The index pointing from the source id to the target id
Attributes:
id: canonical id to identify the relationship between the source and the target
source_id: corresponds to Source.id
target_id: corresponds to the id of the indexed and processed entries (e.g.
embedding vector, document...)
relation_type: the type of relationship between the source and the target
(corresponds to SourceTargetRelation)
"""
__table_args__ = {"extend_existing": True}
id: Optional[int] = Field(default=None, primary_key=True, index=True)
source_id: str
target_id: str
relation_type: Optional[SourceTargetRelation] = Field(default=None)
class BaseConversation(SQLModel): class BaseConversation(SQLModel):
"""Store the chat conversation between the user and the bot """Store the chat conversation between the user and the bot

View File

@ -4,18 +4,6 @@ from sqlmodel import SQLModel
from theflow.settings import settings from theflow.settings import settings
from theflow.utils.modules import import_dotted_string from theflow.utils.modules import import_dotted_string
_base_source = (
import_dotted_string(settings.KH_TABLE_SOURCE, safe=False)
if hasattr(settings, "KH_TABLE_SOURCE")
else base_models.BaseSource
)
_base_index = (
import_dotted_string(settings.KH_TABLE_INDEX, safe=False)
if hasattr(settings, "KH_TABLE_INDEX")
else base_models.BaseIndex
)
_base_conv = ( _base_conv = (
import_dotted_string(settings.KH_TABLE_CONV, safe=False) import_dotted_string(settings.KH_TABLE_CONV, safe=False)
if hasattr(settings, "KH_TABLE_CONV") if hasattr(settings, "KH_TABLE_CONV")
@ -41,14 +29,6 @@ _base_issue_report = (
) )
class Source(_base_source, table=True): # type: ignore
"""Record the source of the document"""
class Index(_base_index, table=True): # type: ignore
"""The index pointing from the original id to the target id"""
class Conversation(_base_conv, table=True): # type: ignore class Conversation(_base_conv, table=True): # type: ignore
"""Conversation record""" """Conversation record"""
@ -65,8 +45,5 @@ class IssueReport(_base_issue_report, table=True): # type: ignore
"""Record of issues""" """Record of issues"""
SourceTargetRelation = base_models.SourceTargetRelation
if not getattr(settings, "KH_ENABLE_ALEMBIC", False): if not getattr(settings, "KH_ENABLE_ALEMBIC", False):
SQLModel.metadata.create_all(engine) SQLModel.metadata.create_all(engine)

View File

@ -0,0 +1,3 @@
from .manager import IndexManager
__all__ = ["IndexManager"]

View File

@ -0,0 +1,127 @@
import abc
import logging
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from ktem.app import BasePage
from kotaemon.base import BaseComponent
logger = logging.getLogger(__name__)
class BaseIndex(abc.ABC):
"""The base class for the index
The index is responsible for storing information in a searchable manner, and
retrieving that information.
An application can have multiple indices. For example:
- An index of files locally in the computer
- An index of chat messages on Discord, Slack, etc.
- An index of files stored on Google Drie, Dropbox, etc.
- ...
User can create, delete, and manage the indices in this application. They
can create an index, set it to track a local folder in their computer, and
then the chatbot can search for files in that folder. The user can create
another index to track their chat messages on Discords. And so on.
This class defines the interface for the index. It concerns with:
- Setting up the necessary software infrastructure for the index to work
(e.g. database table, vector store collection, etc.).
- Providing the UI for user interaction with the index, including settings.
Methods:
__init__: initiate any resource definition required for the index to work
(e.g. database table, vector store collection, etc.).
on_create: called only once, when the user creates the index.
on_delete: called only once, when the user deletes the index.
on_start: called when the index starts.
get_selector_component_ui: return the UI component to select the entities in
the Chat page. Called in the ChatUI page.
get_index_page_ui: return the index page UI to manage the entities. Called in
the main application UI page.
get_user_settings: return default user settings. Called only when the app starts
get_admin_settings: return the admin settings. Called only when the user
creates the index (for the admin to customize it). The output will be
stored in the Index's config.
get_indexing_pipeline: return the indexing pipeline when the entities are
populated into the index
get_retriever_pipelines: return the retriever pipelines when the user chat
"""
def __init__(self, app, id, name, config):
self._app = app
self.id = id
self.name = name
self._config = config # admin settings
def on_create(self):
"""Create the index for the first time"""
def on_delete(self):
"""Trigger when the user delete the index"""
def on_start(self):
"""Trigger when the index start
Args:
id (int): the id of the index
name (str): the name of the index
config (dict): the config of the index
"""
def get_selector_component_ui(self) -> Optional["BasePage"]:
"""The UI component to select the entities in the Chat page"""
return None
def get_index_page_ui(self) -> Optional["BasePage"]:
"""The index page UI to manage the entities"""
return None
@classmethod
def get_user_settings(cls) -> dict:
"""Return default user settings. These are the runtime settings.
The settings will be populated in the user settings page. And will be used
when initiating the indexing & retriever pipelines.
Returns:
dict: user settings in the dictionary format of
`ktem.settings.SettingItem`
"""
return {}
@classmethod
def get_admin_settings(cls) -> dict:
"""Return the default admin settings. These are the build-time settings.
The settings will be populated in the admin settings page. And will be used
when initiating the indexing & retriever pipelines.
Returns:
dict: user settings in the dictionary format of
`ktem.settings.SettingItem`
"""
return {}
@abc.abstractmethod
def get_indexing_pipeline(self, settings: dict) -> "BaseComponent":
"""Return the indexing pipeline that populates the entities into the index
Args:
settings: the user settings of the index
Returns:
BaseIndexing: the indexing pipeline
"""
...
def get_retriever_pipelines(
self, settings: dict, selected: Optional[list]
) -> list["BaseComponent"]:
"""Return the retriever pipelines to retrieve the entity from the index"""
return []

View File

@ -0,0 +1,3 @@
from .index import FileIndex
__all__ = ["FileIndex"]

View File

@ -0,0 +1,91 @@
from pathlib import Path
from typing import Optional
from kotaemon.base import BaseComponent
class BaseFileIndexRetriever(BaseComponent):
@classmethod
def get_user_settings(cls) -> dict:
"""Get the user settings for indexing
Returns:
dict: user settings in the dictionary format of
`ktem.settings.SettingItem`
"""
return {}
@classmethod
def get_pipeline(
cls,
user_settings: dict,
index_settings: dict,
selected: Optional[list] = None,
) -> "BaseFileIndexRetriever":
raise NotImplementedError
def set_resources(self, resources: dict):
"""Set the resources for the indexing pipeline
This will setup the tables, the vector store and docstore.
Args:
resources (dict): the resources for the indexing pipeline
"""
self._Source = resources["Source"]
self._Index = resources["Index"]
self._VS = resources["VectorStore"]
self._DS = resources["DocStore"]
class BaseFileIndexIndexing(BaseComponent):
"""The pipeline to index information into the data store
You should define the following method:
- run(self, file_paths): run the indexing given the pipeline
- get_pipeline(cls, user_settings, index_settings): return the
fully-initialized pipeline, ready to be used by ktem.
You will have access to the following resources:
- self._Source: the source table
- self._Index: the index table
- self._VS: the vector store
- self._DS: the docstore
"""
def run(self, file_paths: str | Path | list[str | Path], *args, **kwargs):
"""Run the indexing pipeline
Args:
file_paths (str | Path | list[str | Path]): the file paths to index
"""
raise NotImplementedError
@classmethod
def get_pipeline(
cls, user_settings: dict, index_settings: dict
) -> "BaseFileIndexIndexing":
raise NotImplementedError
@classmethod
def get_user_settings(cls) -> dict:
"""Get the user settings for indexing
Returns:
dict: user settings in the dictionary format of
`ktem.settings.SettingItem`
"""
return {}
def set_resources(self, resources: dict):
"""Set the resources for the indexing pipeline
This will setup the tables, the vector store and docstore.
Args:
resources (dict): the resources for the indexing pipeline
"""
self._Source = resources["Source"]
self._Index = resources["Index"]
self._VS = resources["VectorStore"]
self._DS = resources["DocStore"]

View File

@ -0,0 +1,245 @@
import uuid
from typing import Any, Optional, Type
from ktem.components import get_docstore, get_vectorstore
from ktem.db.engine import engine
from ktem.index.base import BaseIndex
from sqlalchemy import Column, DateTime, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string
from kotaemon.storages import BaseDocumentStore, BaseVectorStore
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .ui import FileIndexPage, FileSelector
class FileIndex(BaseIndex):
"""Index for the uploaded files
The file index stores files in a local folder and index them for retrieval.
This file index provides the following infrastructure to support the indexing:
- SQL table Source: store the list of files that are indexed by the system
- Vector store: contain the embedding of segments of the files
- Document store: contain the text of segments of the files. Each text stored
in this document store is associated with a vector in the vector store.
- SQL table Index: store the relationship between (1) the source and the
docstore, and (2) the source and the vector store.
"""
def __init__(self, app, id: int, name: str, config: dict):
super().__init__(app, id, name, config)
Base = declarative_base()
Source = type(
"Source",
(Base,),
{
"__tablename__": f"index__{self.id}__source",
"id": Column(
String,
primary_key=True,
default=lambda: str(uuid.uuid4()),
unique=True,
),
"name": Column(String, unique=True),
"path": Column(String),
"size": Column(Integer),
"text_length": Column(Integer, default=0),
"date_created": Column(
DateTime(timezone=True), server_default=func.now()
),
},
)
Index = type(
"IndexTable",
(Base,),
{
"__tablename__": f"index__{self.id}__index",
"id": Column(Integer, primary_key=True, autoincrement=True),
"source_id": Column(String),
"target_id": Column(String),
"relation_type": Column(Integer),
},
)
self._db_tables: dict[str, Any] = {"Source": Source, "Index": Index}
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
self._docstore: BaseDocumentStore = get_docstore(f"index_{self.id}")
self._resources = {
"Source": Source,
"Index": Index,
"VectorStore": self._vs,
"DocStore": self._docstore,
}
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
self._setup_indexing_cls()
self._setup_retriever_cls()
self._default_settings: dict[str, dict] = {}
self._setting_mappings: dict[str, dict] = {}
def _setup_indexing_cls(self):
"""Retrieve the indexing class for the file index
There is only one indexing class.
The indexing class will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_PIPELINE` in self._config
- `FILE_INDEX_{id}_PIPELINE` in the flowsettings
- `FILE_INDEX_PIPELINE` in the flowsettings
- The default .pipelines.IndexDocumentPipeline
"""
if "FILE_INDEX_PIPELINE" in self._config:
self._indexing_pipeline_cls = import_dotted_string(
self._config["FILE_INDEX_PIPELINE"]
)
return
if hasattr(flowsettings, f"FILE_INDEX_{self.id}_PIPELINE"):
self._indexing_pipeline_cls = import_dotted_string(
getattr(flowsettings, f"FILE_INDEX_{self.id}_PIPELINE")
)
return
if hasattr(flowsettings, "FILE_INDEX_PIPELINE"):
self._indexing_pipeline_cls = import_dotted_string(
getattr(flowsettings, "FILE_INDEX_PIPELINE")
)
return
from .pipelines import IndexDocumentPipeline
self._indexing_pipeline_cls = IndexDocumentPipeline
def _setup_retriever_cls(self):
"""Retrieve the retriever classes for the file index
There can be multiple retriever classes.
The retriever classes will is retrieved from the following order. Stop at the
first order found:
- `FILE_INDEX_RETRIEVER_PIPELINES` in self._config
- `FILE_INDEX_{id}_RETRIEVER_PIPELINES` in the flowsettings
- `FILE_INDEX_RETRIEVER_PIPELINES` in the flowsettings
- The default .pipelines.DocumentRetrievalPipeline
"""
if "FILE_INDEX_RETRIEVER_PIPELINES" in self._config:
self._retriever_pipeline_cls = [
import_dotted_string(each)
for each in self._config["FILE_INDEX_RETRIEVER_PIPELINES"]
]
return
if hasattr(flowsettings, f"FILE_INDEX_{self.id}_RETRIEVER_PIPELINES"):
self._retriever_pipeline_cls = [
import_dotted_string(each)
for each in getattr(
flowsettings, f"FILE_INDEX_{self.id}_RETRIEVER_PIPELINES"
)
]
return
if hasattr(flowsettings, "FILE_INDEX_RETRIEVER_PIPELINES"):
self._retriever_pipeline_cls = [
import_dotted_string(each)
for each in getattr(flowsettings, "FILE_INDEX_RETRIEVER_PIPELINE")
]
return
from .pipelines import DocumentRetrievalPipeline
self._retriever_pipeline_cls = [DocumentRetrievalPipeline]
def on_create(self):
"""Create the index for the first time
For the file index, this will:
1. Create the index and the source table if not already exists
2. Create the vectorstore
3. Create the docstore
"""
self._resources["Source"].metadata.create_all(engine) # type: ignore
self._resources["Index"].metadata.create_all(engine) # type: ignore
def on_delete(self):
"""Clean up the index when the user delete it"""
self._resources["Source"].__table__.drop(engine) # type: ignore
self._resources["Index"].__table__.drop(engine) # type: ignore
self._vs.drop()
self._docstore.drop()
def get_selector_component_ui(self):
return FileSelector(self._app, self)
def get_index_page_ui(self):
return FileIndexPage(self._app, self)
def get_user_settings(self):
if self._default_settings:
return self._default_settings
settings = {}
settings.update(self._indexing_pipeline_cls.get_user_settings())
for cls in self._retriever_pipeline_cls:
settings.update(cls.get_user_settings())
self._default_settings = settings
return settings
@classmethod
def get_admin_settings(cls):
from ktem.components import embeddings
embedding_default = embeddings.get_default_name()
embedding_choices = list(embeddings.options().keys())
return {
"embedding": {
"name": "Embedding model",
"value": embedding_default,
"component": "dropdown",
"choices": embedding_choices,
}
}
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
"""Define the interface of the indexing pipeline"""
prefix = f"index.options.{self.id}."
stripped_settings = {}
for key, value in settings.items():
if key.startswith(prefix):
stripped_settings[key[len(prefix) :]] = value
else:
stripped_settings[key] = value
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self._config)
obj.set_resources(resources=self._resources)
return obj
def get_retriever_pipelines(
self, settings: dict, selected: Optional[list] = None
) -> list["BaseFileIndexRetriever"]:
prefix = f"index.options.{self.id}."
stripped_settings = {}
for key, value in settings.items():
if key.startswith(prefix):
stripped_settings[key[len(prefix) :]] = value
else:
stripped_settings[key] = value
retrievers = []
for cls in self._retriever_pipeline_cls:
obj = cls.get_pipeline(stripped_settings, self._config, selected)
if obj is None:
continue
obj.set_resources(self._resources)
retrievers.append(obj)
return retrievers

View File

@ -3,19 +3,13 @@ from __future__ import annotations
import shutil import shutil
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from functools import lru_cache
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from ktem.components import ( from ktem.components import embeddings, filestorage_path, llms
embeddings, from ktem.db.models import engine
filestorage_path,
get_docstore,
get_vectorstore,
llms,
)
from ktem.db.models import Index, Source, SourceTargetRelation, engine
from ktem.indexing.base import BaseIndexing, BaseRetriever
from llama_index.vector_stores import ( from llama_index.vector_stores import (
FilterCondition, FilterCondition,
FilterOperator, FilterOperator,
@ -23,64 +17,42 @@ from llama_index.vector_stores import (
MetadataFilters, MetadataFilters,
) )
from llama_index.vector_stores.types import VectorStoreQueryMode from llama_index.vector_stores.types import VectorStoreQueryMode
from sqlmodel import Session, select from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings from theflow.settings import settings
from theflow.utils.modules import import_dotted_string
from kotaemon.base import RetrievedDocument from kotaemon.base import RetrievedDocument
from kotaemon.indices import VectorIndexing, VectorRetrieval from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests import DocumentIngestor from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
USER_SETTINGS = { from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
"index_parser": {
"name": "Index parser",
"value": "normal",
"choices": [
("PDF text parser", "normal"),
("Mathpix", "mathpix"),
("Advanced ocr", "ocr"),
],
"component": "dropdown",
},
"separate_embedding": {
"name": "Use separate embedding",
"value": False,
"choices": [("Yes", True), ("No", False)],
"component": "dropdown",
},
"num_retrieval": {
"name": "Number of documents to retrieve",
"value": 3,
"component": "number",
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "vector",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
"prioritize_table": {
"name": "Prioritize table",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
"mmr": {
"name": "Use MMR",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
"use_reranking": {
"name": "Use reranking",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
}
class DocumentRetrievalPipeline(BaseRetriever): @lru_cache
def dev_settings():
"""Retrieve the developer settings from flowsettings.py"""
file_extractors = {}
if hasattr(settings, "FILE_INDEX_PIPELINE_FILE_EXTRACTORS"):
file_extractors = {
key: import_dotted_string(value, safe=False)
for key, value in settings.FILE_INDEX_PIPELINE_FILE_EXTRACTORS.items()
}
chunk_size = None
if hasattr(settings, "FILE_INDEX_PIPELINE_SPLITTER_CHUNK_SIZE"):
chunk_size = settings.FILE_INDEX_PIPELINE_SPLITTER_CHUNK_SIZE
chunk_overlap = None
if hasattr(settings, "FILE_INDEX_PIPELINE_SPLITTER_CHUNK_OVERLAP"):
chunk_overlap = settings.FILE_INDEX_PIPELINE_SPLITTER_CHUNK_OVERLAP
return file_extractors, chunk_size, chunk_overlap
class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""Retrieve relevant document """Retrieve relevant document
Args: Args:
@ -93,8 +65,6 @@ class DocumentRetrievalPipeline(BaseRetriever):
""" """
vector_retrieval: VectorRetrieval = VectorRetrieval.withx( vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
doc_store=get_docstore(),
vector_store=get_vectorstore(),
embedding=embeddings.get_default(), embedding=embeddings.get_default(),
) )
reranker: BaseReranking = CohereReranking.withx( reranker: BaseReranking = CohereReranking.withx(
@ -117,15 +87,17 @@ class DocumentRetrievalPipeline(BaseRetriever):
mmr: whether to use mmr to re-rank the documents mmr: whether to use mmr to re-rank the documents
doc_ids: list of document ids to constraint the retrieval doc_ids: list of document ids to constraint the retrieval
""" """
Index = self._Index
kwargs = {} kwargs = {}
if doc_ids: if doc_ids:
with Session(engine) as session: with Session(engine) as session:
stmt = select(Index).where( stmt = select(Index).where(
Index.relation_type == SourceTargetRelation.VECTOR, Index.relation_type == "vector",
Index.source_id.in_(doc_ids), # type: ignore Index.source_id.in_(doc_ids), # type: ignore
) )
results = session.exec(stmt) results = session.execute(stmt)
vs_ids = [r.target_id for r in results.all()] vs_ids = [r[0].target_id for r in results.all()]
kwargs["filters"] = MetadataFilters( kwargs["filters"] = MetadataFilters(
filters=[ filters=[
@ -181,8 +153,73 @@ class DocumentRetrievalPipeline(BaseRetriever):
return docs return docs
@classmethod
def get_user_settings(cls) -> dict:
return {
"separate_embedding": {
"name": "Use separate embedding",
"value": False,
"choices": [("Yes", True), ("No", False)],
"component": "dropdown",
},
"num_retrieval": {
"name": "Number of documents to retrieve",
"value": 3,
"component": "number",
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "vector",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
"prioritize_table": {
"name": "Prioritize table",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
"mmr": {
"name": "Use MMR",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
"use_reranking": {
"name": "Use reranking",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
}
class IndexDocumentPipeline(BaseIndexing): @classmethod
def get_pipeline(cls, user_settings, index_settings, selected):
"""Get retriever objects associated with the index
Args:
settings: the settings of the app
kwargs: other arguments
"""
retriever = cls(get_extra_table=user_settings["prioritize_table"])
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
kwargs = {
".top_k": int(user_settings["num_retrieval"]),
".mmr": user_settings["mmr"],
".doc_ids": selected,
}
retriever.set_run(kwargs, temp=True)
return retriever
def set_resources(self, resources: dict):
super().set_resources(resources)
self.vector_retrieval.vector_store = self._VS
self.vector_retrieval.doc_store = self._DS
class IndexDocumentPipeline(BaseFileIndexIndexing):
"""Store the documents and index the content into vector store and doc store """Store the documents and index the content into vector store and doc store
Args: Args:
@ -191,8 +228,6 @@ class IndexDocumentPipeline(BaseIndexing):
""" """
indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx( indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx(
doc_store=get_docstore(),
vector_store=get_vectorstore(),
embedding=embeddings.get_default(), embedding=embeddings.get_default(),
) )
file_ingestor: DocumentIngestor = DocumentIngestor.withx() file_ingestor: DocumentIngestor = DocumentIngestor.withx()
@ -215,6 +250,9 @@ class IndexDocumentPipeline(BaseIndexing):
Returns: Returns:
list of split nodes list of split nodes
""" """
Source = self._Source
Index = self._Index
if not isinstance(file_paths, list): if not isinstance(file_paths, list):
file_paths = [file_paths] file_paths = [file_paths]
@ -231,11 +269,11 @@ class IndexDocumentPipeline(BaseIndexing):
with Session(engine) as session: with Session(engine) as session:
statement = select(Source).where(Source.name == Path(abs_path).name) statement = select(Source).where(Source.name == Path(abs_path).name)
item = session.exec(statement).first() item = session.execute(statement).first()
if item and not reindex: if item and not reindex:
errors.append(Path(abs_path).name) errors.append(Path(abs_path).name)
continue continue
to_index.append(abs_path) to_index.append(abs_path)
@ -245,22 +283,26 @@ class IndexDocumentPipeline(BaseIndexing):
f"{errors}" f"{errors}"
) )
if not to_index:
return [], []
# persist the files to storage # persist the files to storage
for path in to_index: for path in to_index:
shutil.copy(path, filestorage_path / file_to_hash[path]) shutil.copy(path, filestorage_path / file_to_hash[path])
# prepare record info # prepare record info
file_to_source: dict[str, Source] = {} file_to_source: dict = {}
for file_path, file_hash in file_to_hash.items(): for file_path, file_hash in file_to_hash.items():
source = Source(path=file_hash, name=Path(file_path).name) source = Source(
name=Path(file_path).name,
path=file_hash,
size=Path(file_path).stat().st_size,
)
file_to_source[file_path] = source file_to_source[file_path] = source
# extract the files # extract the files
nodes = self.file_ingestor(to_index) nodes = self.file_ingestor(to_index)
print("Extracted", len(to_index), "files into", len(nodes), "nodes") print("Extracted", len(to_index), "files into", len(nodes), "nodes")
for node in nodes:
file_path = str(node.metadata["file_path"])
node.source = file_to_source[file_path].id
# index the files # index the files
print("Indexing the files into vector store") print("Indexing the files into vector store")
@ -277,19 +319,27 @@ class IndexDocumentPipeline(BaseIndexing):
for source in file_to_source.values(): for source in file_to_source.values():
file_ids.append(source.id) file_ids.append(source.id)
for node in nodes:
file_path = str(node.metadata["file_path"])
node.source = str(file_to_source[file_path].id)
file_to_source[file_path].text_length += len(node.text)
session.flush()
session.commit()
with Session(engine) as session: with Session(engine) as session:
for node in nodes: for node in nodes:
index = Index( index = Index(
source_id=node.source, source_id=node.source,
target_id=node.doc_id, target_id=node.doc_id,
relation_type=SourceTargetRelation.DOCUMENT, relation_type="document",
) )
session.add(index) session.add(index)
for node in nodes: for node in nodes:
index = Index( index = Index(
source_id=node.source, source_id=node.source,
target_id=node.doc_id, target_id=node.doc_id,
relation_type=SourceTargetRelation.VECTOR, relation_type="vector",
) )
session.add(index) session.add(index)
session.commit() session.commit()
@ -298,33 +348,38 @@ class IndexDocumentPipeline(BaseIndexing):
print(f"{len(nodes)} nodes are indexed") print(f"{len(nodes)} nodes are indexed")
return nodes, file_ids return nodes, file_ids
def get_user_settings(self) -> dict: @classmethod
return USER_SETTINGS def get_user_settings(cls) -> dict:
return {
"index_parser": {
"name": "Index parser",
"value": "normal",
"choices": [
("PDF text parser", "normal"),
("Mathpix", "mathpix"),
("Advanced ocr", "ocr"),
],
"component": "dropdown",
},
}
@classmethod @classmethod
def get_pipeline(cls, settings) -> "IndexDocumentPipeline": def get_pipeline(cls, user_settings, index_settings) -> "IndexDocumentPipeline":
"""Get the pipeline based on the setting""" """Get the pipeline based on the setting"""
obj = cls() obj = cls()
obj.file_ingestor.pdf_mode = settings["index.index_parser"] obj.file_ingestor.pdf_mode = user_settings["index_parser"]
file_extractors, chunk_size, chunk_overlap = dev_settings()
if file_extractors:
obj.file_ingestor.override_file_extractors = file_extractors
if chunk_size:
obj.file_ingestor.text_splitter.chunk_size = chunk_size
if chunk_overlap:
obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap
return obj return obj
def get_retrievers(self, settings, **kwargs) -> list[BaseRetriever]: def set_resources(self, resources: dict):
"""Get retriever objects associated with the index super().set_resources(resources)
self.indexing_vector_pipeline.vector_store = self._VS
Args: self.indexing_vector_pipeline.doc_store = self._DS
settings: the settings of the app
kwargs: other arguments
"""
retriever = DocumentRetrievalPipeline(
get_extra_table=settings["index.prioritize_table"]
)
if not settings["index.use_reranking"]:
retriever.reranker = None # type: ignore
kwargs = {
".top_k": int(settings["index.num_retrieval"]),
".mmr": settings["index.mmr"],
".doc_ids": kwargs.get("files", None),
}
retriever.set_run(kwargs, temp=True)
return [retriever]

View File

@ -0,0 +1,469 @@
import os
import tempfile
import gradio as gr
import pandas as pd
from ktem.app import BasePage
from ktem.db.engine import engine
from sqlalchemy import select
from sqlalchemy.orm import Session
class DirectoryUpload(BasePage):
def __init__(self, app):
self._app = app
self._supported_file_types = [
"image",
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
self.on_building_ui()
def on_building_ui(self):
with gr.Accordion(label="Directory upload", open=False):
gr.Markdown(
f"Supported file types: {', '.join(self._supported_file_types)}",
)
self.path = gr.Textbox(
placeholder="Directory path...", lines=1, max_lines=1, container=False
)
with gr.Accordion("Advanced indexing options", open=False):
with gr.Row():
self.reindex = gr.Checkbox(
value=False, label="Force reindex file", container=False
)
self.upload_button = gr.Button("Upload and Index")
self.file_output = gr.File(
visible=False, label="Output files (debug purpose)"
)
class FileIndexPage(BasePage):
def __init__(self, app, index):
super().__init__(app)
self._index = index
self._supported_file_types = [
"image",
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
self.selected_panel_false = "Selected file: (please select above)"
self.selected_panel_true = "Selected file: {name}"
# TODO: on_building_ui is not correctly named if it's always called in
# the constructor
self.public_events = [f"onFileIndex{index.id}Changed"]
self.on_building_ui()
def on_building_ui(self):
"""Build the UI of the app"""
with gr.Accordion(label="File upload", open=False):
gr.Markdown(
f"Supported file types: {', '.join(self._supported_file_types)}",
)
self.files = gr.File(
file_types=self._supported_file_types,
file_count="multiple",
container=False,
)
with gr.Accordion("Advanced indexing options", open=False):
with gr.Row():
self.reindex = gr.Checkbox(
value=False, label="Force reindex file", container=False
)
self.upload_button = gr.Button("Upload and Index")
self.file_output = gr.File(
visible=False, label="Output files (debug purpose)"
)
gr.Markdown("## File list")
self.file_list_state = gr.State(value=None)
self.file_list = gr.DataFrame(
headers=["id", "name", "size", "text_length", "date_created"],
interactive=False,
)
with gr.Row():
self.selected_file_id = gr.State(value=None)
self.selected_panel = gr.Markdown(self.selected_panel_false)
self.deselect_button = gr.Button("Deselect", visible=False)
with gr.Row():
with gr.Column():
self.view_button = gr.Button("View Text (WIP)")
with gr.Column():
self.delete_button = gr.Button("Delete")
with gr.Row():
self.delete_yes = gr.Button("Confirm Delete", visible=False)
self.delete_no = gr.Button("Cancel", visible=False)
def on_subscribe_public_events(self):
"""Subscribe to the declared public event of the app"""
def file_selected(self, file_id):
if file_id is None:
deselect = gr.update(visible=False)
else:
deselect = gr.update(visible=True)
return (
deselect,
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
)
def to_confirm_delete(self, file_id):
if file_id is None:
gr.Warning("No file is selected")
return (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
)
return (
gr.update(visible=False),
gr.update(visible=True),
gr.update(visible=True),
)
def delete_yes_event(self, file_id):
with Session(engine) as session:
source = session.execute(
select(self._index._db_tables["Source"]).where(
self._index._db_tables["Source"].id == file_id
)
).first()
if source:
session.delete(source[0])
vs_ids, ds_ids = [], []
index = session.execute(
select(self._index._db_tables["Index"]).where(
self._index._db_tables["Index"].source_id == file_id
)
).all()
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self._index._vs.delete(vs_ids)
self._index._docstore.delete(ds_ids)
gr.Info(f"File {file_id} has been deleted")
return None, self.selected_panel_false
def delete_no_event(self):
return (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
)
def on_register_events(self):
"""Register all events to the app"""
self.delete_button.click(
fn=self.to_confirm_delete,
inputs=[self.selected_file_id],
outputs=[self.delete_button, self.delete_yes, self.delete_no],
show_progress="hidden",
)
onDeleted = (
self.delete_yes.click(
fn=self.delete_yes_event,
inputs=[self.selected_file_id],
outputs=None,
)
.then(
fn=lambda: (None, self.selected_panel_false),
inputs=None,
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
)
.then(
fn=self.list_file,
inputs=None,
outputs=[self.file_list_state, self.file_list],
)
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onDeleted = onDeleted.then(**event)
self.delete_no.click(
fn=self.delete_no_event,
inputs=None,
outputs=[self.delete_button, self.delete_yes, self.delete_no],
show_progress="hidden",
)
self.deselect_button.click(
fn=lambda: (None, self.selected_panel_false),
inputs=None,
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
)
self.selected_panel.change(
fn=self.file_selected,
inputs=[self.selected_file_id],
outputs=[
self.deselect_button,
self.delete_button,
self.delete_yes,
self.delete_no,
],
show_progress="hidden",
)
onUploaded = self.upload_button.click(
fn=self.index_fn,
inputs=[
self.files,
self.reindex,
self._app.settings_state,
],
outputs=[self.file_output],
).then(
fn=self.list_file,
inputs=None,
outputs=[self.file_list_state, self.file_list],
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onUploaded = onUploaded.then(**event)
self.file_list.select(
fn=self.interact_file_list,
inputs=[self.file_list],
outputs=[self.selected_file_id, self.selected_panel],
show_progress="hidden",
)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
self.list_file,
inputs=None,
outputs=[self.file_list_state, self.file_list],
)
def index_fn(self, files, reindex: bool, settings):
"""Upload and index the files
Args:
files: the list of files to be uploaded
reindex: whether to reindex the files
selected_files: the list of files already selected
settings: the settings of the app
"""
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
indexing_pipeline = self._index.get_indexing_pipeline(settings)
output_nodes, _ = indexing_pipeline(files, reindex=reindex)
gr.Info(f"Finish indexing into {len(output_nodes)} chunks")
# download the file
text = "\n\n".join([each.text for each in output_nodes])
handler, file_path = tempfile.mkstemp(suffix=".txt")
with open(file_path, "w") as f:
f.write(text)
os.close(handler)
return gr.update(value=file_path, visible=True)
def index_files_from_dir(self, folder_path, reindex, settings):
"""This should be constructable by users
It means that the users can build their own index.
Build your own index:
- Input:
- Type: based on the type, then there are ranges of. Use can select
multiple panels:
- Panels
- Data sources
- Include patterns
- Exclude patterns
- Indexing functions. Can be a list of indexing functions. Each declared
function is:
- Condition (the source that will go through this indexing function)
- Function (the pipeline that run this)
- Output: artifacts that can be used to -> this is the artifacts that we
wish
- Build the UI
- Upload page: fixed standard, based on the type
- Read page: fixed standard, based on the type
- Delete page: fixed standard, based on the type
- Build the index function
- Build the chat function
Step:
1. Decide on the artifacts
2. Implement the transformation from artifacts to UI
"""
if not folder_path:
return
import fnmatch
from pathlib import Path
include_patterns: list[str] = []
exclude_patterns: list[str] = ["*.png", "*.gif", "*/.*"]
if include_patterns and exclude_patterns:
raise ValueError("Cannot have both include and exclude patterns")
# clean up the include patterns
for idx in range(len(include_patterns)):
if include_patterns[idx].startswith("*"):
include_patterns[idx] = str(Path.cwd() / "**" / include_patterns[idx])
else:
include_patterns[idx] = str(
Path.cwd() / include_patterns[idx].strip("/")
)
# clean up the exclude patterns
for idx in range(len(exclude_patterns)):
if exclude_patterns[idx].startswith("*"):
exclude_patterns[idx] = str(Path.cwd() / "**" / exclude_patterns[idx])
else:
exclude_patterns[idx] = str(
Path.cwd() / exclude_patterns[idx].strip("/")
)
# get the files
files: list[str] = [str(p) for p in Path(folder_path).glob("**/*.*")]
if include_patterns:
for p in include_patterns:
files = fnmatch.filter(names=files, pat=p)
if exclude_patterns:
for p in exclude_patterns:
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
return self.index_fn(files, reindex, settings)
def list_file(self):
Source = self._index._db_tables["Source"]
with Session(engine) as session:
statement = select(Source)
results = [
{
"id": each[0].id,
"name": each[0].name,
"size": each[0].size,
"text_length": each[0].text_length,
"date_created": each[0].date_created,
}
for each in session.execute(statement).all()
]
if results:
file_list = pd.DataFrame.from_records(results)
else:
file_list = pd.DataFrame.from_records(
[
{
"id": "-",
"name": "-",
"size": "-",
"text_length": "-",
"date_created": "-",
}
]
)
return results, file_list
def interact_file_list(self, list_files, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No file is uploaded")
return None, self.selected_panel_false
if not ev.selected:
return None, self.selected_panel_false
return list_files["id"][ev.index[0]], self.selected_panel_true.format(
name=list_files["name"][ev.index[0]]
)
def delete(self, file_id):
pass
def cancel_delete(self):
pass
class FileSelector(BasePage):
"""File selector UI in the Chat page"""
def __init__(self, app, index):
super().__init__(app)
self._index = index
self.on_building_ui()
def on_building_ui(self):
self.selector = gr.Dropdown(
label="Files",
choices=[],
multiselect=True,
container=False,
interactive=True,
)
def as_gradio_component(self):
return self.selector
def load_files(self, selected_files):
options = []
available_ids = []
with Session(engine) as session:
statement = select(self._index._db_tables["Source"])
results = session.execute(statement).all()
for result in results:
available_ids.append(result[0].id)
options.append((result[0].name, result[0].id))
if selected_files:
available_ids_set = set(available_ids)
selected_files = [
each for each in selected_files if each in available_ids_set
]
return gr.update(value=selected_files, choices=options)
def _on_app_created(self):
self._app.app.load(
self.load_files,
inputs=self.selector,
outputs=[self.selector],
)
def on_subscribe_public_events(self):
self._app.subscribe_event(
name=f"onFileIndex{self._index.id}Changed",
definition={
"fn": self.load_files,
"inputs": [self.selector],
"outputs": [self.selector],
"show_progress": "hidden",
},
)

View File

@ -0,0 +1,113 @@
from typing import Type
from ktem.db.models import engine
from sqlmodel import Session, select
from theflow.settings import settings
from theflow.utils.modules import import_dotted_string
from .base import BaseIndex
from .models import Index
class IndexManager:
"""Manage the application indices
The index manager is responsible for:
- Managing the range of possible indices and their extensions
- Each actual index built by user
Attributes:
- indices: list of indices built by user
"""
def __init__(self, app):
self._app = app
self._indices = []
self._index_types = {}
def add_index_type(self, cls: Type[BaseIndex]):
"""Register index type to the system"""
self._index_types[cls.__name__] = cls
def list_index_types(self) -> dict:
"""List the index_type of the index"""
return self._index_types
def build_index(self, name: str, config: dict, index_type: str, id=None):
"""Build the index
Building the index simply means recording the index information into the
database and returning the index object.
Args:
name (str): the name of the index
config (dict): the config of the index
index_type (str): the type of the index
id (int, optional): the id of the index. If None, the id will be
generated automatically. Defaults to None.
Returns:
BaseIndex: the index object
"""
with Session(engine) as session:
index_entry = Index(id=id, name=name, config=config, index_type=index_type)
session.add(index_entry)
session.commit()
session.refresh(index_entry)
index_cls = import_dotted_string(index_type, safe=False)
index = index_cls(app=self._app, id=id, name=name, config=config)
index.on_create()
return index
def start_index(self, id: int, name: str, config: dict, index_type: str):
"""Start the index
Args:
id (int): the id of the index
name (str): the name of the index
config (dict): the config of the index
index_type (str): the type of the index
"""
index_cls = import_dotted_string(index_type, safe=False)
index = index_cls(app=self._app, id=id, name=name, config=config)
index.on_start()
self._indices.append(index)
return index
def exists(self, id: int) -> bool:
"""Check if the index exists
Args:
id (int): the id of the index
Returns:
bool: True if the index exists, False otherwise
"""
with Session(engine) as session:
index = session.get(Index, id)
return index is not None
def on_application_startup(self):
"""This method is called by the base application when the application starts
Load the index from database
"""
for index in settings.KH_INDEX_TYPES:
index_cls = import_dotted_string(index, safe=False)
self.add_index_type(index_cls)
for index in settings.KH_INDICES:
if not self.exists(index["id"]):
self.build_index(**index)
with Session(engine) as session:
index_defs = session.exec(select(Index))
for index_def in index_defs:
self.start_index(**index_def.dict())
@property
def indices(self):
return self._indices

View File

@ -0,0 +1,19 @@
from typing import Optional
from ktem.db.engine import engine
from sqlalchemy import JSON, Column
from sqlmodel import Field, SQLModel
# TODO: simplify with using SQLAlchemy directly
class Index(SQLModel, table=True):
__table_args__ = {"extend_existing": True}
__tablename__ = "ktem__index" # type: ignore
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(unique=True)
index_type: str = Field()
config: dict = Field(default={}, sa_column=Column(JSON))
Index.metadata.create_all(engine)

View File

@ -1,25 +0,0 @@
from kotaemon.base import BaseComponent
class BaseRetriever(BaseComponent):
pass
class BaseIndexing(BaseComponent):
"""The pipeline to index information into the data store"""
def get_user_settings(self) -> dict:
"""Get the user settings for indexing
Returns:
dict: user settings in the dictionary format of
`ktem.settings.SettingItem`
"""
return {}
@classmethod
def get_pipeline(cls, settings: dict) -> "BaseIndexing":
raise NotImplementedError
def get_retrievers(self, settings: dict, **kwargs) -> list[BaseRetriever]:
raise NotImplementedError

View File

@ -24,6 +24,15 @@ class App(BaseApp):
with gr.Tab("Chat", elem_id="chat-tab"): with gr.Tab("Chat", elem_id="chat-tab"):
self.chat_page = ChatPage(self) self.chat_page = ChatPage(self)
for index in self.index_manager.indices:
with gr.Tab(
f"{index.name} Index",
elem_id=f"{index.id}-tab",
elem_classes="indices-tab",
):
page = index.get_index_page_ui()
setattr(self, f"_index_{index.id}", page)
with gr.Tab("Settings", elem_id="settings-tab"): with gr.Tab("Settings", elem_id="settings-tab"):
self.settings_page = SettingsPage(self) self.settings_page = SettingsPage(self)

View File

@ -1,33 +1,44 @@
import asyncio
from copy import deepcopy
from typing import Optional
import gradio as gr import gradio as gr
from ktem.app import BasePage from ktem.app import BasePage
from ktem.components import reasonings
from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from .chat_panel import ChatPanel from .chat_panel import ChatPanel
from .control import ConversationControl from .control import ConversationControl
from .data_source import DataSource
from .events import (
chat_fn,
index_files_from_dir,
index_fn,
is_liked,
load_files,
update_data_source,
)
from .report import ReportIssue from .report import ReportIssue
from .upload import DirectoryUpload, FileUpload
class ChatPage(BasePage): class ChatPage(BasePage):
def __init__(self, app): def __init__(self, app):
self._app = app self._app = app
self._indices_input = []
self.on_building_ui() self.on_building_ui()
def on_building_ui(self): def on_building_ui(self):
with gr.Row(): with gr.Row():
with gr.Column(scale=1): with gr.Column(scale=1):
self.chat_control = ConversationControl(self._app) self.chat_control = ConversationControl(self._app)
self.data_source = DataSource(self._app)
self.file_upload = FileUpload(self._app) for index in self._app.index_manager.indices:
self.dir_upload = DirectoryUpload(self._app) index.selector = -1
index_ui = index.get_selector_component_ui()
if not index_ui:
continue
index_ui.unrender()
with gr.Accordion(label=f"{index.name} Index", open=False):
index_ui.render()
gr_index = index_ui.as_gradio_component()
if gr_index:
index.selector = len(self._indices_input)
self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)
self.report_issue = ReportIssue(self._app) self.report_issue = ReportIssue(self._app)
with gr.Column(scale=6): with gr.Column(scale=6):
self.chat_panel = ChatPanel(self._app) self.chat_panel = ChatPanel(self._app)
@ -36,19 +47,23 @@ class ChatPage(BasePage):
self.info_panel = gr.HTML(elem_id="chat-info-panel") self.info_panel = gr.HTML(elem_id="chat-info-panel")
def on_register_events(self): def on_register_events(self):
self.chat_panel.submit_btn.click( gr.on(
self.chat_panel.submit_msg, triggers=[
self.chat_panel.text_input.submit,
self.chat_panel.submit_btn.click,
],
fn=self.chat_panel.submit_msg,
inputs=[self.chat_panel.text_input, self.chat_panel.chatbot], inputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot], outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
show_progress="hidden", show_progress="hidden",
).then( ).then(
fn=chat_fn, fn=self.chat_fn,
inputs=[ inputs=[
self.chat_control.conversation_id, self.chat_control.conversation_id,
self.chat_panel.chatbot, self.chat_panel.chatbot,
self.data_source.files,
self._app.settings_state, self._app.settings_state,
], ]
+ self._indices_input,
outputs=[ outputs=[
self.chat_panel.text_input, self.chat_panel.text_input,
self.chat_panel.chatbot, self.chat_panel.chatbot,
@ -56,46 +71,17 @@ class ChatPage(BasePage):
], ],
show_progress="minimal", show_progress="minimal",
).then( ).then(
fn=update_data_source, fn=self.update_data_source,
inputs=[
self.chat_control.conversation_id,
self.data_source.files,
self.chat_panel.chatbot,
],
outputs=None,
)
self.chat_panel.text_input.submit(
self.chat_panel.submit_msg,
inputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
outputs=[self.chat_panel.text_input, self.chat_panel.chatbot],
show_progress="hidden",
).then(
fn=chat_fn,
inputs=[ inputs=[
self.chat_control.conversation_id, self.chat_control.conversation_id,
self.chat_panel.chatbot, self.chat_panel.chatbot,
self.data_source.files, ]
self._app.settings_state, + self._indices_input,
],
outputs=[
self.chat_panel.text_input,
self.chat_panel.chatbot,
self.info_panel,
],
show_progress="minimal",
).then(
fn=update_data_source,
inputs=[
self.chat_control.conversation_id,
self.data_source.files,
self.chat_panel.chatbot,
],
outputs=None, outputs=None,
) )
self.chat_panel.chatbot.like( self.chat_panel.chatbot.like(
fn=is_liked, fn=self.is_liked,
inputs=[self.chat_control.conversation_id], inputs=[self.chat_control.conversation_id],
outputs=None, outputs=None,
) )
@ -107,9 +93,9 @@ class ChatPage(BasePage):
self.chat_control.conversation_id, self.chat_control.conversation_id,
self.chat_control.conversation, self.chat_control.conversation,
self.chat_control.conversation_rn, self.chat_control.conversation_rn,
self.data_source.files,
self.chat_panel.chatbot, self.chat_panel.chatbot,
], ]
+ self._indices_input,
show_progress="hidden", show_progress="hidden",
) )
@ -121,47 +107,112 @@ class ChatPage(BasePage):
self.report_issue.more_detail, self.report_issue.more_detail,
self.chat_control.conversation_id, self.chat_control.conversation_id,
self.chat_panel.chatbot, self.chat_panel.chatbot,
self.data_source.files,
self._app.settings_state, self._app.settings_state,
self._app.user_id, self._app.user_id,
], ]
+ self._indices_input,
outputs=None, outputs=None,
) )
self.data_source.files.input( def update_data_source(self, convo_id, messages, *selecteds):
fn=update_data_source, """Update the data source"""
inputs=[ if not convo_id:
self.chat_control.conversation_id, gr.Warning("No conversation selected")
self.data_source.files, return
self.chat_panel.chatbot,
],
outputs=None,
)
self.file_upload.upload_button.click( selecteds_ = {}
fn=index_fn, for index in self._app.index_manager.indices:
inputs=[ if index.selector != -1:
self.file_upload.files, selecteds_[str(index.id)] = selecteds[index.selector]
self.file_upload.reindex,
self.data_source.files,
self._app.settings_state,
],
outputs=[self.file_upload.file_output, self.data_source.files],
)
self.dir_upload.upload_button.click( with Session(engine) as session:
fn=index_files_from_dir, statement = select(Conversation).where(Conversation.id == convo_id)
inputs=[ result = session.exec(statement).one()
self.dir_upload.path,
self.dir_upload.reindex,
self.data_source.files,
self._app.settings_state,
],
outputs=[self.dir_upload.file_output, self.data_source.files],
)
self._app.app.load( data_source = result.data_source
lambda: gr.update(choices=load_files()), result.data_source = {
inputs=None, "selected": selecteds_,
outputs=[self.data_source.files], "messages": messages,
) "likes": deepcopy(data_source.get("likes", [])),
}
session.add(result)
session.commit()
def is_liked(self, convo_id, liked: gr.LikeData):
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
result = session.exec(statement).one()
data_source = deepcopy(result.data_source)
likes = data_source.get("likes", [])
likes.append([liked.index, liked.value, liked.liked])
data_source["likes"] = likes
result.data_source = data_source
session.add(result)
session.commit()
def create_pipeline(self, settings: dict, *selecteds):
"""Create the pipeline from settings
Args:
settings: the settings of the app
selected: the list of file ids that will be served as context. If None, then
consider using all files
Returns:
the pipeline objects
"""
# get retrievers
retrievers = []
for index in self._app.index_manager.indices:
index_selected = []
if index.selector != -1:
index_selected = selecteds[index.selector]
iretrievers = index.get_retriever_pipelines(settings, index_selected)
retrievers += iretrievers
reasoning_mode = settings["reasoning.use"]
reasoning_cls = reasonings[reasoning_mode]
pipeline = reasoning_cls.get_pipeline(settings, retrievers)
return pipeline
async def chat_fn(self, conversation_id, chat_history, settings, *selecteds):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
# construct the pipeline
pipeline = self.create_pipeline(settings, *selecteds)
pipeline.set_output_queue(queue)
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
text, refs = "", ""
len_ref = -1 # for logging purpose
while True:
try:
response = queue.get_nowait()
except Exception:
yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
continue
if response is None:
queue.task_done()
print("Chat completed")
break
if "output" in response:
text += response["output"]
if "evidence" in response:
refs += response["evidence"]
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
yield "", chat_history + [(chat_input, text)], refs

View File

@ -166,15 +166,23 @@ class ConversationControl(BasePage):
result = session.exec(statement).one() result = session.exec(statement).one()
id_ = result.id id_ = result.id
name = result.name name = result.name
files = result.data_source.get("files", []) selected = result.data_source.get("selected", {})
chats = result.data_source.get("messages", []) chats = result.data_source.get("messages", [])
except Exception as e: except Exception as e:
logger.warning(e) logger.warning(e)
id_ = "" id_ = ""
name = "" name = ""
files = [] selected = {}
chats = [] chats = []
return id_, id_, name, files, chats
indices = []
for index in self._app.index_manager.indices:
# assume that the index has selector
if index.selector == -1:
continue
indices.append(selected.get(str(index.id), []))
return id_, id_, name, chats, *indices
def rename_conv(self, conversation_id, new_name, user_id): def rename_conv(self, conversation_id, new_name, user_id):
"""Rename the conversation""" """Rename the conversation"""

View File

@ -1,18 +0,0 @@
import gradio as gr
from ktem.app import BasePage
class DataSource(BasePage):
def __init__(self, app):
self._app = app
self.on_building_ui()
def on_building_ui(self):
with gr.Accordion(label="Data source", open=True):
self.files = gr.Dropdown(
label="Files",
choices=[],
multiselect=True,
container=False,
interactive=True,
)

View File

@ -1,305 +0,0 @@
import asyncio
import os
import tempfile
from copy import deepcopy
from typing import Optional, Type
import gradio as gr
from ktem.components import llms, reasonings
from ktem.db.models import Conversation, Source, engine
from ktem.indexing.base import BaseIndexing
from sqlmodel import Session, select
from theflow.settings import settings as app_settings
from theflow.utils.modules import import_dotted_string
def create_pipeline(settings: dict, files: Optional[list] = None):
"""Create the pipeline from settings
Args:
settings: the settings of the app
files: the list of file ids that will be served as context. If None, then
consider using all files
Returns:
the pipeline objects
"""
# get retrievers
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
retrievers = indexing_cls.get_pipeline(settings).get_retrievers(
settings, files=files
)
reasoning_mode = settings["reasoning.use"]
reasoning_cls = reasonings[reasoning_mode]
pipeline = reasoning_cls.get_pipeline(settings, retrievers, files=files)
if settings["reasoning.use"] in ["rewoo", "react"]:
from kotaemon.agents import ReactAgent, RewooAgent
llm = (
llms["gpt4"]
if settings["answer_simple_llm_model"] == "gpt-4"
else llms["gpt35"]
)
tools = []
tools_keys = (
"answer_rewoo_tools"
if settings["reasoning.use"] == "rewoo"
else "answer_react_tools"
)
for tool in settings[tools_keys]:
if tool == "llm":
from kotaemon.agents import LLMTool
tools.append(LLMTool(llm=llm))
# elif tool == "docsearch":
# pass
# filenames = ""
# if files:
# with Session(engine) as session:
# statement = select(Source).where(
# Source.id.in_(files) # type: ignore
# )
# results = session.exec(statement).all()
# filenames = (
# "The file names are: "
# + " ".join([result.name for result in results])
# + ". "
# )
# tool = ComponentTool(
# name="docsearch",
# description=(
# "A vector store that searches for similar and "
# "related content "
# f"in a document. {filenames}"
# "The result is a huge chunk of text related "
# "to your search but can also "
# "contain irrelevant info."
# ),
# component=retrieval_pipeline,
# postprocessor=lambda docs: "\n\n".join(
# [doc.text.replace("\n", " ") for doc in docs]
# ),
# )
# tools.append(tool)
elif tool == "google":
from kotaemon.agents import GoogleSearchTool
tools.append(GoogleSearchTool())
elif tool == "wikipedia":
from kotaemon.agents import WikipediaTool
tools.append(WikipediaTool())
else:
raise NotImplementedError(f"Unknown tool: {tool}")
if settings["reasoning.use"] == "rewoo":
pipeline = RewooAgent(
planner_llm=llm,
solver_llm=llm,
plugins=tools,
)
pipeline.set_run({".use_citation": True})
else:
pipeline = ReactAgent(
llm=llm,
plugins=tools,
)
return pipeline
async def chat_fn(conversation_id, chat_history, files, settings):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
# construct the pipeline
pipeline = create_pipeline(settings, files)
pipeline.set_output_queue(queue)
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
text, refs = "", ""
len_ref = -1 # for logging purpose
while True:
try:
response = queue.get_nowait()
except Exception:
yield "", chat_history + [(chat_input, text or "Thinking ...")], refs
continue
if response is None:
queue.task_done()
print("Chat completed")
break
if "output" in response:
text += response["output"]
if "evidence" in response:
refs += response["evidence"]
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
yield "", chat_history + [(chat_input, text)], refs
def is_liked(convo_id, liked: gr.LikeData):
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
result = session.exec(statement).one()
data_source = deepcopy(result.data_source)
likes = data_source.get("likes", [])
likes.append([liked.index, liked.value, liked.liked])
data_source["likes"] = likes
result.data_source = data_source
session.add(result)
session.commit()
def update_data_source(convo_id, selected_files, messages):
"""Update the data source"""
if not convo_id:
gr.Warning("No conversation selected")
return
with Session(engine) as session:
statement = select(Conversation).where(Conversation.id == convo_id)
result = session.exec(statement).one()
data_source = result.data_source
result.data_source = {
"files": selected_files,
"messages": messages,
"likes": deepcopy(data_source.get("likes", [])),
}
session.add(result)
session.commit()
def load_files():
options = []
with Session(engine) as session:
statement = select(Source)
results = session.exec(statement).all()
for result in results:
options.append((result.name, result.id))
return options
def index_fn(files, reindex: bool, selected_files, settings):
"""Upload and index the files
Args:
files: the list of files to be uploaded
reindex: whether to reindex the files
selected_files: the list of files already selected
settings: the settings of the app
"""
gr.Info(f"Start indexing {len(files)} files...")
# get the pipeline
indexing_cls: Type[BaseIndexing] = import_dotted_string(
app_settings.KH_INDEX, safe=False
)
indexing_pipeline = indexing_cls.get_pipeline(settings)
output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
gr.Info(f"Finish indexing into {len(output_nodes)} chunks")
# download the file
text = "\n\n".join([each.text for each in output_nodes])
handler, file_path = tempfile.mkstemp(suffix=".txt")
with open(file_path, "w") as f:
f.write(text)
os.close(handler)
if isinstance(selected_files, list):
output = selected_files + file_ids
else:
output = file_ids
file_list = load_files()
return (
gr.update(value=file_path, visible=True),
gr.update(value=output, choices=file_list), # unnecessary
)
def index_files_from_dir(folder_path, reindex, selected_files, settings):
"""This should be constructable by users
It means that the users can build their own index.
Build your own index:
- Input:
- Type: based on the type, then there are ranges of. Use can select multiple
panels:
- Panels
- Data sources
- Include patterns
- Exclude patterns
- Indexing functions. Can be a list of indexing functions. Each declared
function is:
- Condition (the source that will go through this indexing function)
- Function (the pipeline that run this)
- Output: artifacts that can be used to -> this is the artifacts that we wish
- Build the UI
- Upload page: fixed standard, based on the type
- Read page: fixed standard, based on the type
- Delete page: fixed standard, based on the type
- Build the index function
- Build the chat function
Step:
1. Decide on the artifacts
2. Implement the transformation from artifacts to UI
"""
if not folder_path:
return
import fnmatch
from pathlib import Path
include_patterns: list[str] = []
exclude_patterns: list[str] = ["*.png", "*.gif", "*/.*"]
if include_patterns and exclude_patterns:
raise ValueError("Cannot have both include and exclude patterns")
# clean up the include patterns
for idx in range(len(include_patterns)):
if include_patterns[idx].startswith("*"):
include_patterns[idx] = str(Path.cwd() / "**" / include_patterns[idx])
else:
include_patterns[idx] = str(Path.cwd() / include_patterns[idx].strip("/"))
# clean up the exclude patterns
for idx in range(len(exclude_patterns)):
if exclude_patterns[idx].startswith("*"):
exclude_patterns[idx] = str(Path.cwd() / "**" / exclude_patterns[idx])
else:
exclude_patterns[idx] = str(Path.cwd() / exclude_patterns[idx].strip("/"))
# get the files
files: list[str] = [str(p) for p in Path(folder_path).glob("**/*.*")]
if include_patterns:
for p in include_patterns:
files = fnmatch.filter(names=files, pat=p)
if exclude_patterns:
for p in exclude_patterns:
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
return index_fn(files, reindex, selected_files, settings)

View File

@ -46,10 +46,15 @@ class ReportIssue(BasePage):
more_detail: str, more_detail: str,
conv_id: str, conv_id: str,
chat_history: list, chat_history: list,
files: list,
settings: dict, settings: dict,
user_id: Optional[int], user_id: Optional[int],
*selecteds
): ):
selecteds_ = {}
for index in self._app.index_manager.indices:
if index.selector != -1:
selecteds_[str(index.id)] = selecteds[index.selector]
with Session(engine) as session: with Session(engine) as session:
issue = IssueReport( issue = IssueReport(
issues={ issues={
@ -60,7 +65,7 @@ class ReportIssue(BasePage):
chat={ chat={
"conv_id": conv_id, "conv_id": conv_id,
"chat_history": chat_history, "chat_history": chat_history,
"files": files, "selecteds": selecteds_,
}, },
settings=settings, settings=settings,
user=user_id, user=user_id,

View File

@ -1,79 +0,0 @@
import gradio as gr
from ktem.app import BasePage
class FileUpload(BasePage):
def __init__(self, app):
self._app = app
self._supported_file_types = [
"image",
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
self.on_building_ui()
def on_building_ui(self):
with gr.Accordion(label="File upload", open=False):
gr.Markdown(
f"Supported file types: {', '.join(self._supported_file_types)}",
)
self.files = gr.File(
file_types=self._supported_file_types,
file_count="multiple",
container=False,
height=50,
)
with gr.Accordion("Advanced indexing options", open=False):
with gr.Row():
self.reindex = gr.Checkbox(
value=False, label="Force reindex file", container=False
)
self.upload_button = gr.Button("Upload and Index")
self.file_output = gr.File(
visible=False, label="Output files (debug purpose)"
)
class DirectoryUpload(BasePage):
def __init__(self, app):
self._app = app
self._supported_file_types = [
"image",
".pdf",
".txt",
".csv",
".xlsx",
".doc",
".docx",
".pptx",
".html",
".zip",
]
self.on_building_ui()
def on_building_ui(self):
with gr.Accordion(label="Directory upload", open=False):
gr.Markdown(
f"Supported file types: {', '.join(self._supported_file_types)}",
)
self.path = gr.Textbox(
placeholder="Directory path...", lines=1, max_lines=1, container=False
)
with gr.Accordion("Advanced indexing options", open=False):
with gr.Row():
self.reindex = gr.Checkbox(
value=False, label="Force reindex file", container=False
)
self.upload_button = gr.Button("Upload and Index")
self.file_output = gr.File(
visible=False, label="Output files (debug purpose)"
)

View File

@ -329,9 +329,17 @@ class SettingsPage(BasePage):
self._components[f"application.{n}"] = obj self._components[f"application.{n}"] = obj
def index_tab(self): def index_tab(self):
for n, si in self._default_settings.index.settings.items(): # TODO: double check if we need general
obj = render_setting_item(si, si.value) # with gr.Tab("General"):
self._components[f"index.{n}"] = obj # for n, si in self._default_settings.index.settings.items():
# obj = render_setting_item(si, si.value)
# self._components[f"index.{n}"] = obj
for pn, sig in self._default_settings.index.options.items():
with gr.Tab(f"Index {pn}"):
for n, si in sig.settings.items():
obj = render_setting_item(si, si.value)
self._components[f"index.options.{pn}.{n}"] = obj
def reasoning_tab(self): def reasoning_tab(self):
with gr.Group(): with gr.Group():

View File

@ -0,0 +1,5 @@
from kotaemon.base import BaseComponent
class BaseReasoning(BaseComponent):
retrievers: list = []

View File

@ -5,7 +5,7 @@ from functools import partial
import tiktoken import tiktoken
from ktem.components import llms from ktem.components import llms
from ktem.indexing.base import BaseRetriever from ktem.reasoning.base import BaseReasoning
from kotaemon.base import ( from kotaemon.base import (
BaseComponent, BaseComponent,
@ -210,20 +210,25 @@ class AnswerWithContextPipeline(BaseComponent):
self.report_output({"output": text.text}) self.report_output({"output": text.text})
await asyncio.sleep(0) await asyncio.sleep(0)
citation = self.citation_pipeline(context=evidence, question=question) try:
citation = self.citation_pipeline(context=evidence, question=question)
except Exception as e:
print(e)
citation = None
answer = Document(text=output, metadata={"citation": citation}) answer = Document(text=output, metadata={"citation": citation})
return answer return answer
class FullQAPipeline(BaseComponent): class FullQAPipeline(BaseReasoning):
"""Question answering pipeline. Handle from question to answer""" """Question answering pipeline. Handle from question to answer"""
class Config: class Config:
allow_extra = True allow_extra = True
params_publish = True
retrievers: list[BaseRetriever] retrievers: list[BaseComponent]
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx() evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx() answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
@ -244,18 +249,19 @@ class FullQAPipeline(BaseComponent):
# prepare citation # prepare citation
spans = defaultdict(list) spans = defaultdict(list)
for fact_with_evidence in answer.metadata["citation"].answer: if answer.metadata["citation"] is not None:
for quote in fact_with_evidence.substring_quote: for fact_with_evidence in answer.metadata["citation"].answer:
for doc in docs: for quote in fact_with_evidence.substring_quote:
start_idx = doc.text.find(quote) for doc in docs:
if start_idx >= 0: start_idx = doc.text.find(quote)
spans[doc.doc_id].append( if start_idx >= 0:
{ spans[doc.doc_id].append(
"start": start_idx, {
"end": start_idx + len(quote), "start": start_idx,
} "end": start_idx + len(quote),
) }
break )
break
id2docs = {doc.doc_id: doc for doc in docs} id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True lack_evidence = True
@ -308,7 +314,7 @@ class FullQAPipeline(BaseComponent):
return answer return answer
@classmethod @classmethod
def get_pipeline(cls, settings, retrievers, **kwargs): def get_pipeline(cls, settings, retrievers):
"""Get the reasoning pipeline """Get the reasoning pipeline
Args: Args:

View File

@ -115,13 +115,6 @@ class SettingIndexGroup(BaseSettingGroup):
return output return output
def finalize(self):
"""Finalize the setting"""
options = list(self.options.keys())
if options:
self.settings["use"].choices = [(x, x) for x in options]
self.settings["use"].value = options
class SettingGroup(BaseModel): class SettingGroup(BaseModel):
application: BaseSettingGroup = Field(default_factory=BaseSettingGroup) application: BaseSettingGroup = Field(default_factory=BaseSettingGroup)

View File

@ -9,7 +9,7 @@ packages.find.exclude = ["tests*", "env*"]
[project] [project]
name = "ktem" name = "ktem"
version = "0.0.1" version = "0.1.0"
requires-python = ">= 3.10" requires-python = ">= 3.10"
description = "RAG-based Question and Answering Application" description = "RAG-based Question and Answering Application"
dependencies = [ dependencies = [

View File

@ -1,29 +0,0 @@
import time
from ktem.db.models import Conversation, Source, engine
from sqlmodel import Session
def add_conversation():
"""Add conversation to the manager."""
with Session(engine) as session:
c1 = Conversation(name="Conversation 1")
c2 = Conversation()
session.add(c1)
time.sleep(1)
session.add(c2)
time.sleep(1)
session.commit()
def add_files():
with Session(engine) as session:
s1 = Source(name="Source 1", path="Path 1")
s2 = Source(name="Source 2", path="Path 2")
session.add(s1)
session.add(s2)
session.commit()
# add_conversation()
add_files()

View File

@ -10,6 +10,8 @@ nav:
- Contributing: contributing.md - Contributing: contributing.md
- Application: - Application:
- Features: pages/app/features.md - Features: pages/app/features.md
- Index:
- File index: pages/app/index/file.md
- Customize flow logic: pages/app/customize-flows.md - Customize flow logic: pages/app/customize-flows.md
- Customize UI: pages/app/customize-ui.md - Customize UI: pages/app/customize-ui.md
- Functional description: pages/app/functional-description.md - Functional description: pages/app/functional-description.md