Allow file index to be private (#45)
* Fix breaking reranker * Allow private file index * Avoid setting default to 1 when user management is enabled
This commit is contained in:
parent
456f020caf
commit
e29bec6275
|
@ -109,11 +109,16 @@ class BaseIndex(abc.ABC):
|
|||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_indexing_pipeline(self, settings: dict) -> "BaseComponent":
|
||||
def get_indexing_pipeline(
|
||||
self, settings: dict, user_id: Optional[int]
|
||||
) -> "BaseComponent":
|
||||
"""Return the indexing pipeline that populates the entities into the index
|
||||
|
||||
Args:
|
||||
settings: the user settings of the index
|
||||
user_id: the user id who is accessing the index
|
||||
TODO: instead of having a user_id, should have an app_state
|
||||
which might also contain the settings.
|
||||
|
||||
Returns:
|
||||
BaseIndexing: the indexing pipeline
|
||||
|
|
|
@ -31,6 +31,26 @@ class FileIndex(BaseIndex):
|
|||
|
||||
def __init__(self, app, id: int, name: str, config: dict):
|
||||
super().__init__(app, id, name, config)
|
||||
|
||||
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
|
||||
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
|
||||
self._selector_ui_cls: Type
|
||||
self._selector_ui: Any = None
|
||||
self._index_ui_cls: Type
|
||||
self._index_ui: Any = None
|
||||
|
||||
self._default_settings: dict[str, dict] = {}
|
||||
self._setting_mappings: dict[str, dict] = {}
|
||||
|
||||
def _setup_resources(self):
|
||||
"""Setup resources for the file index
|
||||
|
||||
The resources include:
|
||||
- Database table
|
||||
- Vector store
|
||||
- Document store
|
||||
- File storage path
|
||||
"""
|
||||
Base = declarative_base()
|
||||
Source = type(
|
||||
"Source",
|
||||
|
@ -50,6 +70,7 @@ class FileIndex(BaseIndex):
|
|||
"date_created": Column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
),
|
||||
"user": Column(Integer, default=1),
|
||||
},
|
||||
)
|
||||
Index = type(
|
||||
|
@ -61,6 +82,7 @@ class FileIndex(BaseIndex):
|
|||
"source_id": Column(String),
|
||||
"target_id": Column(String),
|
||||
"relation_type": Column(Integer),
|
||||
"user": Column(Integer, default=1),
|
||||
},
|
||||
)
|
||||
self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
|
||||
|
@ -74,16 +96,6 @@ class FileIndex(BaseIndex):
|
|||
"FileStoragePath": self._fs_path,
|
||||
}
|
||||
|
||||
self._indexing_pipeline_cls: Type[BaseFileIndexIndexing]
|
||||
self._retriever_pipeline_cls: list[Type[BaseFileIndexRetriever]]
|
||||
self._selector_ui_cls: Type
|
||||
self._selector_ui: Any = None
|
||||
self._index_ui_cls: Type
|
||||
self._index_ui: Any = None
|
||||
|
||||
self._default_settings: dict[str, dict] = {}
|
||||
self._setting_mappings: dict[str, dict] = {}
|
||||
|
||||
def _setup_indexing_cls(self):
|
||||
"""Retrieve the indexing class for the file index
|
||||
|
||||
|
@ -247,6 +259,7 @@ class FileIndex(BaseIndex):
|
|||
self.config = config
|
||||
|
||||
# create the resources
|
||||
self._setup_resources()
|
||||
self._resources["Source"].metadata.create_all(engine) # type: ignore
|
||||
self._resources["Index"].metadata.create_all(engine) # type: ignore
|
||||
self._fs_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -255,6 +268,7 @@ class FileIndex(BaseIndex):
|
|||
"""Clean up the index when the user delete it"""
|
||||
import shutil
|
||||
|
||||
self._setup_resources()
|
||||
self._resources["Source"].__table__.drop(engine) # type: ignore
|
||||
self._resources["Index"].__table__.drop(engine) # type: ignore
|
||||
self._vs.drop()
|
||||
|
@ -263,6 +277,7 @@ class FileIndex(BaseIndex):
|
|||
|
||||
def on_start(self):
|
||||
"""Setup the classes and hooks"""
|
||||
self._setup_resources()
|
||||
self._setup_indexing_cls()
|
||||
self._setup_retriever_cls()
|
||||
self._setup_file_index_ui_cls()
|
||||
|
@ -326,9 +341,16 @@ class FileIndex(BaseIndex):
|
|||
"Set 0 to disable."
|
||||
),
|
||||
},
|
||||
"private": {
|
||||
"name": "Make private",
|
||||
"value": False,
|
||||
"component": "radio",
|
||||
"choices": [("Yes", True), ("No", False)],
|
||||
"info": "If private, files will not be accessible across users.",
|
||||
},
|
||||
}
|
||||
|
||||
def get_indexing_pipeline(self, settings) -> BaseFileIndexIndexing:
|
||||
def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
|
||||
"""Define the interface of the indexing pipeline"""
|
||||
|
||||
prefix = f"index.options.{self.id}."
|
||||
|
@ -341,6 +363,7 @@ class FileIndex(BaseIndex):
|
|||
|
||||
obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
|
||||
obj.set_resources(resources=self._resources)
|
||||
obj._user_id = user_id
|
||||
|
||||
return obj
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ import gradio as gr
|
|||
from ktem.components import filestorage_path
|
||||
from ktem.db.models import engine
|
||||
from ktem.embeddings.manager import embedding_models_manager
|
||||
from ktem.llms.manager import llms
|
||||
from llama_index.vector_stores import (
|
||||
FilterCondition,
|
||||
FilterOperator,
|
||||
|
@ -28,7 +29,7 @@ from theflow.utils.modules import import_dotted_string
|
|||
from kotaemon.base import RetrievedDocument
|
||||
from kotaemon.indices import VectorIndexing, VectorRetrieval
|
||||
from kotaemon.indices.ingests import DocumentIngestor
|
||||
from kotaemon.indices.rankings import BaseReranking
|
||||
from kotaemon.indices.rankings import BaseReranking, LLMReranking
|
||||
|
||||
from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
|
||||
|
||||
|
@ -72,7 +73,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
|||
"""
|
||||
|
||||
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
|
||||
reranker: BaseReranking
|
||||
reranker: BaseReranking = LLMReranking.withx()
|
||||
get_extra_table: bool = False
|
||||
mmr: bool = False
|
||||
top_k: int = 5
|
||||
|
@ -225,12 +226,15 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
|||
"""
|
||||
retriever = cls(
|
||||
get_extra_table=user_settings["prioritize_table"],
|
||||
reranker=user_settings["reranking_llm"],
|
||||
top_k=user_settings["num_retrieval"],
|
||||
mmr=user_settings["mmr"],
|
||||
)
|
||||
if not user_settings["use_reranking"]:
|
||||
retriever.reranker = None # type: ignore
|
||||
else:
|
||||
retriever.reranker.llm = llms.get(
|
||||
user_settings["reranking_llm"], llms.get_default()
|
||||
)
|
||||
|
||||
retriever.vector_retrieval.embedding = embedding_models_manager[
|
||||
index_settings.get("embedding", embedding_models_manager.get_default_name())
|
||||
|
@ -342,6 +346,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
|
|||
name=Path(file_path).name,
|
||||
path=file_hash,
|
||||
size=Path(file_path).stat().st_size,
|
||||
user=self._user_id, # type: ignore
|
||||
)
|
||||
file_to_source[file_path] = source
|
||||
|
||||
|
|
|
@ -168,6 +168,25 @@ class FileIndexPage(BasePage):
|
|||
|
||||
def on_subscribe_public_events(self):
|
||||
"""Subscribe to the declared public event of the app"""
|
||||
if self._app.f_user_management:
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
definition={
|
||||
"fn": self.list_file,
|
||||
"inputs": [self._app.user_id],
|
||||
"outputs": [self.file_list_state, self.file_list],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
self._app.subscribe_event(
|
||||
name="onSignOut",
|
||||
definition={
|
||||
"fn": self.list_file,
|
||||
"inputs": [self._app.user_id],
|
||||
"outputs": [self.file_list_state, self.file_list],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
|
||||
def file_selected(self, file_id):
|
||||
if file_id is None:
|
||||
|
@ -257,7 +276,7 @@ class FileIndexPage(BasePage):
|
|||
)
|
||||
.then(
|
||||
fn=self.list_file,
|
||||
inputs=None,
|
||||
inputs=[self._app.user_id],
|
||||
outputs=[self.file_list_state, self.file_list],
|
||||
)
|
||||
)
|
||||
|
@ -294,12 +313,13 @@ class FileIndexPage(BasePage):
|
|||
self.files,
|
||||
self.reindex,
|
||||
self._app.settings_state,
|
||||
self._app.user_id,
|
||||
],
|
||||
outputs=[self.file_output],
|
||||
concurrency_limit=20,
|
||||
).then(
|
||||
fn=self.list_file,
|
||||
inputs=None,
|
||||
inputs=[self._app.user_id],
|
||||
outputs=[self.file_list_state, self.file_list],
|
||||
concurrency_limit=20,
|
||||
)
|
||||
|
@ -317,11 +337,11 @@ class FileIndexPage(BasePage):
|
|||
"""Called when the app is created"""
|
||||
self._app.app.load(
|
||||
self.list_file,
|
||||
inputs=None,
|
||||
inputs=[self._app.user_id],
|
||||
outputs=[self.file_list_state, self.file_list],
|
||||
)
|
||||
|
||||
def index_fn(self, files, reindex: bool, settings):
|
||||
def index_fn(self, files, reindex: bool, settings, user_id):
|
||||
"""Upload and index the files
|
||||
|
||||
Args:
|
||||
|
@ -342,7 +362,7 @@ class FileIndexPage(BasePage):
|
|||
gr.Info(f"Start indexing {len(files)} files...")
|
||||
|
||||
# get the pipeline
|
||||
indexing_pipeline = self._index.get_indexing_pipeline(settings)
|
||||
indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id)
|
||||
|
||||
result = indexing_pipeline(files, reindex=reindex)
|
||||
if result is None:
|
||||
|
@ -360,7 +380,7 @@ class FileIndexPage(BasePage):
|
|||
|
||||
return gr.update(value=file_path, visible=True)
|
||||
|
||||
def index_files_from_dir(self, folder_path, reindex, settings):
|
||||
def index_files_from_dir(self, folder_path, reindex, settings, user_id):
|
||||
"""This should be constructable by users
|
||||
|
||||
It means that the users can build their own index.
|
||||
|
@ -428,12 +448,28 @@ class FileIndexPage(BasePage):
|
|||
for p in exclude_patterns:
|
||||
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
|
||||
|
||||
return self.index_fn(files, reindex, settings)
|
||||
return self.index_fn(files, reindex, settings, user_id)
|
||||
|
||||
def list_file(self, user_id):
|
||||
if user_id is None:
|
||||
# not signed in
|
||||
return [], pd.DataFrame.from_records(
|
||||
[
|
||||
{
|
||||
"id": "-",
|
||||
"name": "-",
|
||||
"size": "-",
|
||||
"text_length": "-",
|
||||
"date_created": "-",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def list_file(self):
|
||||
Source = self._index._resources["Source"]
|
||||
with Session(engine) as session:
|
||||
statement = select(Source)
|
||||
if self._index.config.get("private", False):
|
||||
statement = statement.where(Source.user == user_id)
|
||||
results = [
|
||||
{
|
||||
"id": each[0].id,
|
||||
|
@ -513,10 +549,12 @@ class FileSelector(BasePage):
|
|||
self.on_building_ui()
|
||||
|
||||
def default(self):
|
||||
return "disabled", []
|
||||
if self._app.f_user_management:
|
||||
return "disabled", [], -1
|
||||
return "disabled", [], 1
|
||||
|
||||
def on_building_ui(self):
|
||||
default_mode, default_selector = self.default()
|
||||
default_mode, default_selector, user_id = self.default()
|
||||
|
||||
self.mode = gr.Radio(
|
||||
value=default_mode,
|
||||
|
@ -529,25 +567,30 @@ class FileSelector(BasePage):
|
|||
)
|
||||
self.selector = gr.Dropdown(
|
||||
label="Files",
|
||||
choices=default_selector,
|
||||
value=default_selector,
|
||||
choices=[],
|
||||
multiselect=True,
|
||||
container=False,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
self.selector_user_id = gr.State(value=user_id)
|
||||
|
||||
def on_register_events(self):
|
||||
self.mode.change(
|
||||
fn=lambda mode: gr.update(visible=mode == "select"),
|
||||
inputs=[self.mode],
|
||||
outputs=[self.selector],
|
||||
fn=lambda mode, user_id: (gr.update(visible=mode == "select"), user_id),
|
||||
inputs=[self.mode, self._app.user_id],
|
||||
outputs=[self.selector, self.selector_user_id],
|
||||
)
|
||||
|
||||
def as_gradio_component(self):
|
||||
return [self.mode, self.selector]
|
||||
return [self.mode, self.selector, self.selector_user_id]
|
||||
|
||||
def get_selected_ids(self, components):
|
||||
mode, selected = components[0], components[1]
|
||||
mode, selected, user_id = components[0], components[1], components[2]
|
||||
if user_id is None:
|
||||
return []
|
||||
|
||||
if mode == "disabled":
|
||||
return []
|
||||
elif mode == "select":
|
||||
|
@ -556,17 +599,31 @@ class FileSelector(BasePage):
|
|||
file_ids = []
|
||||
with Session(engine) as session:
|
||||
statement = select(self._index._resources["Source"].id)
|
||||
if self._index.config.get("private", False):
|
||||
statement = statement.where(
|
||||
self._index._resources["Source"].user == user_id
|
||||
)
|
||||
results = session.execute(statement).all()
|
||||
for (id,) in results:
|
||||
file_ids.append(id)
|
||||
|
||||
return file_ids
|
||||
|
||||
def load_files(self, selected_files):
|
||||
options = []
|
||||
def load_files(self, selected_files, user_id):
|
||||
options: list = []
|
||||
available_ids = []
|
||||
if user_id is None:
|
||||
# not signed in
|
||||
return gr.update(value=selected_files, choices=options)
|
||||
|
||||
with Session(engine) as session:
|
||||
statement = select(self._index._resources["Source"])
|
||||
if self._index.config.get("private", False):
|
||||
|
||||
statement = statement.where(
|
||||
self._index._resources["Source"].user == user_id
|
||||
)
|
||||
|
||||
results = session.execute(statement).all()
|
||||
for result in results:
|
||||
available_ids.append(result[0].id)
|
||||
|
@ -583,7 +640,7 @@ class FileSelector(BasePage):
|
|||
def _on_app_created(self):
|
||||
self._app.app.load(
|
||||
self.load_files,
|
||||
inputs=self.selector,
|
||||
inputs=[self.selector, self._app.user_id],
|
||||
outputs=[self.selector],
|
||||
)
|
||||
|
||||
|
@ -592,7 +649,26 @@ class FileSelector(BasePage):
|
|||
name=f"onFileIndex{self._index.id}Changed",
|
||||
definition={
|
||||
"fn": self.load_files,
|
||||
"inputs": [self.selector],
|
||||
"inputs": [self.selector, self._app.user_id],
|
||||
"outputs": [self.selector],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
if self._app.f_user_management:
|
||||
self._app.subscribe_event(
|
||||
name="onSignIn",
|
||||
definition={
|
||||
"fn": self.load_files,
|
||||
"inputs": [self.selector, self._app.user_id],
|
||||
"outputs": [self.selector],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
)
|
||||
self._app.subscribe_event(
|
||||
name="onSignOut",
|
||||
definition={
|
||||
"fn": self.load_files,
|
||||
"inputs": [self.selector, self._app.user_id],
|
||||
"outputs": [self.selector],
|
||||
"show_progress": "hidden",
|
||||
},
|
||||
|
|
|
@ -365,7 +365,7 @@ class ChatPage(BasePage):
|
|||
|
||||
Args:
|
||||
settings: the settings of the app
|
||||
is_regen: whether the regen button is clicked
|
||||
state: the state of the app
|
||||
selected: the list of file ids that will be served as context. If None, then
|
||||
consider using all files
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user