Allow file selector to be disabled (#36)

* Allow file selector to be disabled

* Update docs and variable names
This commit is contained in:
Duc Nguyen (john) 2024-04-16 18:43:56 +07:00 committed by GitHub
parent e19893a509
commit 1b2082a140
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 91 additions and 45 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 138 KiB

After

Width:  |  Height:  |  Size: 73 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 66 KiB

After

Width:  |  Height:  |  Size: 40 KiB

View File

@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
1. Conversation Settings Panel
- Here you can select, create, rename, and delete conversations.
- By default, a new conversation is created automatically if no conversation is selected.
- Below that you have the file index, where you can select which files to retrieve references from.
- These are the files you have uploaded to the application from the `File Index` tab.
- If no file is selected, all files will be used.
- Below that you have the file index, where you can choose whether to disable, select all files, or select which files to retrieve references from.
- If you choose "Disabled", no files will be considered as context during chat.
- If you choose "Search All", all files will be considered during chat.
- If you choose "Select", a dropdown will appear for you to select the
files to be considered during chat. If no files are selected, then no
files will be considered during chat.
2. Chat Panel
- This is where you can chat with the chatbot.
3. Information panel

View File

@ -128,9 +128,12 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
1. Conversation Settings Panel
- Here you can select, create, rename, and delete conversations.
- By default, a new conversation is created automatically if no conversation is selected.
- Below that you have the file index, where you can select which files to retrieve references from.
- These are the files you have uploaded to the application from the `File Index` tab.
- If no file is selected, all files will be used.
- Below that you have the file index, where you can choose whether to disable, select all files, or select which files to retrieve references from.
- If you choose "Disabled", no files will be considered as context during chat.
- If you choose "Search All", all files will be considered during chat.
- If you choose "Select", a dropdown will appear for you to select the
files to be considered during chat. If no files are selected, then no
files will be considered during chat.
2. Chat Panel
- This is where you can chat with the chatbot.
3. Information panel

View File

@ -67,58 +67,63 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
documents
get_extra_table: if True, for each retrieved document, the pipeline will look
for surrounding tables (e.g. within the page)
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
"""
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
reranker: BaseReranking
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5
def run(
self,
text: str,
top_k: int = 5,
mmr: bool = False,
doc_ids: Optional[list[str]] = None,
*args,
**kwargs,
) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text
Args:
text: the text to retrieve similar documents
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
doc_ids: list of document ids to constraint the retrieval
"""
if not doc_ids:
logger.info(f"Skip retrieval because of no selected files: {self}")
return []
Index = self._Index
kwargs = {}
if doc_ids:
with Session(engine) as session:
stmt = select(Index).where(
Index.relation_type == "vector",
Index.source_id.in_(doc_ids), # type: ignore
)
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]
kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
retrieval_kwargs = {}
with Session(engine) as session:
stmt = select(Index).where(
Index.relation_type == "vector",
Index.source_id.in_(doc_ids), # type: ignore
)
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]
if mmr:
retrieval_kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="doc_id",
value=vs_id,
operator=FilterOperator.EQ,
)
for vs_id in vs_ids
],
condition=FilterCondition.OR,
)
if self.mmr:
# TODO: double check that llama-index MMR works correctly
kwargs["mode"] = VectorStoreQueryMode.MMR
kwargs["mmr_threshold"] = 0.5
retrieval_kwargs["mode"] = VectorStoreQueryMode.MMR
retrieval_kwargs["mmr_threshold"] = 0.5
# rerank
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs)
if docs and self.get_from_path("reranker"):
docs = self.reranker(docs, query=text)
@ -221,6 +226,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
retriever = cls(
get_extra_table=user_settings["prioritize_table"],
reranker=user_settings["reranking_llm"],
top_k=user_settings["num_retrieval"],
mmr=user_settings["mmr"],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
@ -228,11 +235,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
retriever.vector_retrieval.embedding = embedding_models_manager[
index_settings.get("embedding", embedding_models_manager.get_default_name())
]
kwargs = {
".top_k": int(user_settings["num_retrieval"]),
".mmr": user_settings["mmr"],
".doc_ids": selected,
}
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=True)
return retriever

View File

@ -512,20 +512,55 @@ class FileSelector(BasePage):
self._index = index
self.on_building_ui()
def default(self):
return "disabled", []
def on_building_ui(self):
default_mode, default_selector = self.default()
self.mode = gr.Radio(
value=default_mode,
choices=[
("Disabled", "disabled"),
("Search All", "all"),
("Select", "select"),
],
container=False,
)
self.selector = gr.Dropdown(
label="Files",
choices=[],
choices=default_selector,
multiselect=True,
container=False,
interactive=True,
visible=False,
)
def on_register_events(self):
self.mode.change(
fn=lambda mode: gr.update(visible=mode == "select"),
inputs=[self.mode],
outputs=[self.selector],
)
def as_gradio_component(self):
return self.selector
return [self.mode, self.selector]
def get_selected_ids(self, selected):
return selected
def get_selected_ids(self, components):
mode, selected = components[0], components[1]
if mode == "disabled":
return []
elif mode == "select":
return selected
file_ids = []
with Session(engine) as session:
statement = select(self._index._resources["Source"].id)
results = session.execute(statement).all()
for (id,) in results:
file_ids.append(id)
return file_ids
def load_files(self, selected_files):
options = []

View File

@ -52,9 +52,11 @@ class ChatPage(BasePage):
len(self._indices_input) + len(gr_index),
)
)
index.default_selector = index_ui.default()
self._indices_input.extend(gr_index)
else:
index.selector = len(self._indices_input)
index.default_selector = index_ui.default()
self._indices_input.append(gr_index)
setattr(self, f"_index_{index.id}", index_ui)

View File

@ -156,9 +156,9 @@ class ConversationControl(BasePage):
if index.selector is None:
continue
if isinstance(index.selector, int):
indices.append(selected.get(str(index.id), []))
indices.append(selected.get(str(index.id), index.default_selector))
if isinstance(index.selector, tuple):
indices.extend(selected.get(str(index.id), [[]] * len(index.selector)))
indices.extend(selected.get(str(index.id), index.default_selector))
return id_, id_, name, chats, info_panel, state, *indices