Relate the retrievers to the indexer
This commit is contained in:
parent
9b586466ff
commit
c6637ca56e
|
@ -1,7 +1,13 @@
|
|||
from kotaemon.base import BaseComponent
|
||||
|
||||
|
||||
class BaseIndex(BaseComponent):
|
||||
class BaseRetriever(BaseComponent):
|
||||
pass
|
||||
|
||||
|
||||
class BaseIndexing(BaseComponent):
|
||||
"""The pipeline to index information into the data store"""
|
||||
|
||||
def get_user_settings(self) -> dict:
|
||||
"""Get the user settings for indexing
|
||||
|
||||
|
@ -12,5 +18,8 @@ class BaseIndex(BaseComponent):
|
|||
return {}
|
||||
|
||||
@classmethod
|
||||
def get_pipeline(cls, setting: dict) -> "BaseIndex":
|
||||
def get_pipeline(cls, settings: dict) -> "BaseIndexing":
|
||||
raise NotImplementedError
|
||||
|
||||
def get_retrievers(self, settings: dict, **kwargs) -> list[BaseRetriever]:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,16 +1,35 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from ktem.components import embeddings, filestorage_path, get_docstore, get_vectorstore
|
||||
from ktem.components import (
|
||||
embeddings,
|
||||
filestorage_path,
|
||||
get_docstore,
|
||||
get_vectorstore,
|
||||
llms,
|
||||
)
|
||||
from ktem.db.models import Index, Source, SourceTargetRelation, engine
|
||||
from ktem.indexing.base import BaseIndex
|
||||
from ktem.indexing.base import BaseIndexing, BaseRetriever
|
||||
from ktem.indexing.exceptions import FileExistsError
|
||||
from kotaemon.indices import VectorIndexing
|
||||
from kotaemon.base import RetrievedDocument
|
||||
from kotaemon.indices import VectorIndexing, VectorRetrieval
|
||||
from kotaemon.indices.ingests import DocumentIngestor
|
||||
from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
|
||||
from llama_index.vector_stores import (
|
||||
FilterCondition,
|
||||
FilterOperator,
|
||||
MetadataFilter,
|
||||
MetadataFilters,
|
||||
)
|
||||
from llama_index.vector_stores.types import VectorStoreQueryMode
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings
|
||||
|
||||
USER_SETTINGS = {
|
||||
"index_parser": {
|
||||
|
@ -61,7 +80,109 @@ USER_SETTINGS = {
|
|||
}
|
||||
|
||||
|
||||
class IndexDocumentPipeline(BaseIndex):
|
||||
class DocumentRetrievalPipeline(BaseRetriever):
|
||||
"""Retrieve relevant document
|
||||
|
||||
Args:
|
||||
vector_retrieval: the retrieval pipeline that return the relevant documents
|
||||
given a text query
|
||||
reranker: the reranking pipeline that re-rank and filter the retrieved
|
||||
documents
|
||||
get_extra_table: if True, for each retrieved document, the pipeline will look
|
||||
for surrounding tables (e.g. within the page)
|
||||
"""
|
||||
|
||||
vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
|
||||
doc_store=get_docstore(),
|
||||
vector_store=get_vectorstore(),
|
||||
embedding=embeddings.get_default(),
|
||||
)
|
||||
reranker: BaseReranking = CohereReranking.withx(
|
||||
cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
|
||||
) >> LLMReranking.withx(llm=llms.get_lowest_cost())
|
||||
get_extra_table: bool = False
|
||||
|
||||
def run(
|
||||
self,
|
||||
text: str,
|
||||
top_k: int = 5,
|
||||
mmr: bool = False,
|
||||
doc_ids: Optional[list[str]] = None,
|
||||
) -> list[RetrievedDocument]:
|
||||
"""Retrieve document excerpts similar to the text
|
||||
|
||||
Args:
|
||||
text: the text to retrieve similar documents
|
||||
top_k: number of documents to retrieve
|
||||
mmr: whether to use mmr to re-rank the documents
|
||||
doc_ids: list of document ids to constraint the retrieval
|
||||
"""
|
||||
kwargs = {}
|
||||
if doc_ids:
|
||||
with Session(engine) as session:
|
||||
stmt = select(Index).where(
|
||||
Index.relation_type == SourceTargetRelation.VECTOR,
|
||||
Index.source_id.in_(doc_ids), # type: ignore
|
||||
)
|
||||
results = session.exec(stmt)
|
||||
vs_ids = [r.target_id for r in results.all()]
|
||||
|
||||
kwargs["filters"] = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="doc_id",
|
||||
value=vs_id,
|
||||
operator=FilterOperator.EQ,
|
||||
)
|
||||
for vs_id in vs_ids
|
||||
],
|
||||
condition=FilterCondition.OR,
|
||||
)
|
||||
|
||||
if mmr:
|
||||
# TODO: double check that llama-index MMR works correctly
|
||||
kwargs["mode"] = VectorStoreQueryMode.MMR
|
||||
kwargs["mmr_threshold"] = 0.5
|
||||
|
||||
# rerank
|
||||
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
|
||||
if self.get_from_path("reranker"):
|
||||
docs = self.reranker(docs, query=text)
|
||||
|
||||
if not self.get_extra_table:
|
||||
return docs
|
||||
|
||||
# retrieve extra nodes relate to table
|
||||
table_pages = defaultdict(list)
|
||||
retrieved_id = set([doc.doc_id for doc in docs])
|
||||
for doc in docs:
|
||||
if "page_label" not in doc.metadata:
|
||||
continue
|
||||
if "file_name" not in doc.metadata:
|
||||
warnings.warn(
|
||||
"file_name not in metadata while page_label is in metadata: "
|
||||
f"{doc.metadata}"
|
||||
)
|
||||
table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
|
||||
|
||||
queries = [
|
||||
{"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
|
||||
for fn, pls in table_pages.items()
|
||||
]
|
||||
if queries:
|
||||
extra_docs = self.vector_retrieval(
|
||||
text="",
|
||||
top_k=50,
|
||||
where={"$or": queries},
|
||||
)
|
||||
for doc in extra_docs:
|
||||
if doc.doc_id not in retrieved_id:
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class IndexDocumentPipeline(BaseIndexing):
|
||||
"""Store the documents and index the content into vector store and doc store
|
||||
|
||||
Args:
|
||||
|
@ -175,8 +296,29 @@ class IndexDocumentPipeline(BaseIndex):
|
|||
return USER_SETTINGS
|
||||
|
||||
@classmethod
|
||||
def get_pipeline(cls, setting) -> "IndexDocumentPipeline":
|
||||
def get_pipeline(cls, settings) -> "IndexDocumentPipeline":
|
||||
"""Get the pipeline based on the setting"""
|
||||
obj = cls()
|
||||
obj.file_ingestor.pdf_mode = setting["index.index_parser"]
|
||||
obj.file_ingestor.pdf_mode = settings["index.index_parser"]
|
||||
return obj
|
||||
|
||||
def get_retrievers(self, settings, **kwargs) -> list[BaseRetriever]:
|
||||
"""Get retriever objects associated with the index
|
||||
|
||||
Args:
|
||||
settings: the settings of the app
|
||||
kwargs: other arguments
|
||||
"""
|
||||
retriever = DocumentRetrievalPipeline(
|
||||
get_extra_table=settings["index.prioritize_table"]
|
||||
)
|
||||
if not settings["index.use_reranking"]:
|
||||
retriever.reranker = None # type: ignore
|
||||
|
||||
kwargs = {
|
||||
".top_k": int(settings["index.num_retrieval"]),
|
||||
".mmr": settings["index.mmr"],
|
||||
".doc_ids": kwargs.get("files", None),
|
||||
}
|
||||
retriever.set_run(kwargs, temp=True)
|
||||
return [retriever]
|
||||
|
|
|
@ -36,6 +36,7 @@ class ChatPage(BasePage):
|
|||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
|
@ -64,6 +65,7 @@ class ChatPage(BasePage):
|
|||
).then(
|
||||
fn=chat_fn,
|
||||
inputs=[
|
||||
self.chat_control.conversation_id,
|
||||
self.chat_panel.chatbot,
|
||||
self.data_source.files,
|
||||
self._app.settings_state,
|
||||
|
|
|
@ -7,8 +7,7 @@ from typing import Optional
|
|||
import gradio as gr
|
||||
from ktem.components import llms, reasonings
|
||||
from ktem.db.models import Conversation, Source, engine
|
||||
from ktem.indexing.base import BaseIndex
|
||||
from ktem.reasoning.simple import DocumentRetrievalPipeline
|
||||
from ktem.indexing.base import BaseIndexing
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings as app_settings
|
||||
from theflow.utils.modules import import_dotted_string
|
||||
|
@ -26,9 +25,15 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
|||
the pipeline objects
|
||||
"""
|
||||
|
||||
# get retrievers
|
||||
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
|
||||
retrievers = indexing_cls.get_pipeline(settings).get_retrievers(
|
||||
settings, files=files
|
||||
)
|
||||
|
||||
reasoning_mode = settings["reasoning.use"]
|
||||
reasoning_cls = reasonings[reasoning_mode]
|
||||
pipeline = reasoning_cls.get_pipeline(settings, files=files)
|
||||
pipeline = reasoning_cls.get_pipeline(settings, retrievers, files=files)
|
||||
|
||||
if settings["reasoning.use"] in ["rewoo", "react"]:
|
||||
from kotaemon.agents import ReactAgent, RewooAgent
|
||||
|
@ -49,47 +54,38 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
|||
from kotaemon.agents import LLMTool
|
||||
|
||||
tools.append(LLMTool(llm=llm))
|
||||
elif tool == "docsearch":
|
||||
from kotaemon.agents import ComponentTool
|
||||
# elif tool == "docsearch":
|
||||
# pass
|
||||
|
||||
filenames = ""
|
||||
if files:
|
||||
with Session(engine) as session:
|
||||
statement = select(Source).where(
|
||||
Source.id.in_(files) # type: ignore
|
||||
)
|
||||
results = session.exec(statement).all()
|
||||
filenames = (
|
||||
"The file names are: "
|
||||
+ " ".join([result.name for result in results])
|
||||
+ ". "
|
||||
)
|
||||
# filenames = ""
|
||||
# if files:
|
||||
# with Session(engine) as session:
|
||||
# statement = select(Source).where(
|
||||
# Source.id.in_(files) # type: ignore
|
||||
# )
|
||||
# results = session.exec(statement).all()
|
||||
# filenames = (
|
||||
# "The file names are: "
|
||||
# + " ".join([result.name for result in results])
|
||||
# + ". "
|
||||
# )
|
||||
|
||||
retrieval_pipeline = DocumentRetrievalPipeline()
|
||||
retrieval_pipeline.set_run(
|
||||
{
|
||||
".top_k": int(settings["retrieval_number"]),
|
||||
".mmr": settings["retrieval_mmr"],
|
||||
".doc_ids": files,
|
||||
},
|
||||
temp=True,
|
||||
)
|
||||
tool = ComponentTool(
|
||||
name="docsearch",
|
||||
description=(
|
||||
"A vector store that searches for similar and "
|
||||
"related content "
|
||||
f"in a document. {filenames}"
|
||||
"The result is a huge chunk of text related "
|
||||
"to your search but can also "
|
||||
"contain irrelevant info."
|
||||
),
|
||||
component=retrieval_pipeline,
|
||||
postprocessor=lambda docs: "\n\n".join(
|
||||
[doc.text.replace("\n", " ") for doc in docs]
|
||||
),
|
||||
)
|
||||
tools.append(tool)
|
||||
# tool = ComponentTool(
|
||||
# name="docsearch",
|
||||
# description=(
|
||||
# "A vector store that searches for similar and "
|
||||
# "related content "
|
||||
# f"in a document. {filenames}"
|
||||
# "The result is a huge chunk of text related "
|
||||
# "to your search but can also "
|
||||
# "contain irrelevant info."
|
||||
# ),
|
||||
# component=retrieval_pipeline,
|
||||
# postprocessor=lambda docs: "\n\n".join(
|
||||
# [doc.text.replace("\n", " ") for doc in docs]
|
||||
# ),
|
||||
# )
|
||||
# tools.append(tool)
|
||||
elif tool == "google":
|
||||
from kotaemon.agents import GoogleSearchTool
|
||||
|
||||
|
@ -117,7 +113,7 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
|
|||
return pipeline
|
||||
|
||||
|
||||
async def chat_fn(chat_history, files, settings):
|
||||
async def chat_fn(conversation_id, chat_history, files, settings):
|
||||
"""Chat function"""
|
||||
chat_input = chat_history[-1][0]
|
||||
chat_history = chat_history[:-1]
|
||||
|
@ -128,7 +124,7 @@ async def chat_fn(chat_history, files, settings):
|
|||
pipeline = create_pipeline(settings, files)
|
||||
pipeline.set_output_queue(queue)
|
||||
|
||||
asyncio.create_task(pipeline(chat_input, chat_history))
|
||||
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
|
||||
text, refs = "", ""
|
||||
|
||||
while True:
|
||||
|
@ -207,7 +203,7 @@ def index_fn(files, reindex: bool, selected_files, settings):
|
|||
gr.Info(f"Start indexing {len(files)} files...")
|
||||
|
||||
# get the pipeline
|
||||
indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
|
||||
indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
|
||||
indexing_pipeline = indexing_cls.get_pipeline(settings)
|
||||
|
||||
output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import tiktoken
|
||||
from ktem.components import embeddings, get_docstore, get_vectorstore, llms
|
||||
from ktem.db.models import Index, SourceTargetRelation, engine
|
||||
from ktem.components import llms
|
||||
from ktem.indexing.base import BaseRetriever
|
||||
from kotaemon.base import (
|
||||
BaseComponent,
|
||||
Document,
|
||||
|
@ -16,126 +14,13 @@ from kotaemon.base import (
|
|||
RetrievedDocument,
|
||||
SystemMessage,
|
||||
)
|
||||
from kotaemon.indices import VectorRetrieval
|
||||
from kotaemon.indices.qa.citation import CitationPipeline
|
||||
from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
|
||||
from kotaemon.indices.splitters import TokenSplitter
|
||||
from kotaemon.llms import ChatLLM, PromptTemplate
|
||||
from llama_index.vector_stores import (
|
||||
FilterCondition,
|
||||
FilterOperator,
|
||||
MetadataFilter,
|
||||
MetadataFilters,
|
||||
)
|
||||
from llama_index.vector_stores.types import VectorStoreQueryMode
|
||||
from sqlmodel import Session, select
|
||||
from theflow.settings import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentRetrievalPipeline(BaseComponent):
|
||||
"""Retrieve relevant document
|
||||
|
||||
Args:
|
||||
vector_retrieval: the retrieval pipeline that return the relevant documents
|
||||
given a text query
|
||||
reranker: the reranking pipeline that re-rank and filter the retrieved
|
||||
documents
|
||||
get_extra_table: if True, for each retrieved document, the pipeline will look
|
||||
for surrounding tables (e.g. within the page)
|
||||
"""
|
||||
|
||||
vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
|
||||
doc_store=get_docstore(),
|
||||
vector_store=get_vectorstore(),
|
||||
embedding=embeddings.get_default(),
|
||||
)
|
||||
reranker: BaseReranking = CohereReranking.withx(
|
||||
cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
|
||||
) >> LLMReranking.withx(llm=llms.get_lowest_cost())
|
||||
get_extra_table: bool = False
|
||||
|
||||
def run(
|
||||
self,
|
||||
text: str,
|
||||
top_k: int = 5,
|
||||
mmr: bool = False,
|
||||
doc_ids: Optional[list[str]] = None,
|
||||
) -> list[RetrievedDocument]:
|
||||
"""Retrieve document excerpts similar to the text
|
||||
|
||||
Args:
|
||||
text: the text to retrieve similar documents
|
||||
top_k: number of documents to retrieve
|
||||
mmr: whether to use mmr to re-rank the documents
|
||||
doc_ids: list of document ids to constraint the retrieval
|
||||
"""
|
||||
kwargs = {}
|
||||
if doc_ids:
|
||||
with Session(engine) as session:
|
||||
stmt = select(Index).where(
|
||||
Index.relation_type == SourceTargetRelation.VECTOR,
|
||||
Index.source_id.in_(doc_ids), # type: ignore
|
||||
)
|
||||
results = session.exec(stmt)
|
||||
vs_ids = [r.target_id for r in results.all()]
|
||||
|
||||
kwargs["filters"] = MetadataFilters(
|
||||
filters=[
|
||||
MetadataFilter(
|
||||
key="doc_id",
|
||||
value=vs_id,
|
||||
operator=FilterOperator.EQ,
|
||||
)
|
||||
for vs_id in vs_ids
|
||||
],
|
||||
condition=FilterCondition.OR,
|
||||
)
|
||||
|
||||
if mmr:
|
||||
# TODO: double check that llama-index MMR works correctly
|
||||
kwargs["mode"] = VectorStoreQueryMode.MMR
|
||||
kwargs["mmr_threshold"] = 0.5
|
||||
|
||||
# rerank
|
||||
docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
|
||||
if self.get_from_path("reranker"):
|
||||
docs = self.reranker(docs, query=text)
|
||||
|
||||
if not self.get_extra_table:
|
||||
return docs
|
||||
|
||||
# retrieve extra nodes relate to table
|
||||
table_pages = defaultdict(list)
|
||||
retrieved_id = set([doc.doc_id for doc in docs])
|
||||
for doc in docs:
|
||||
if "page_label" not in doc.metadata:
|
||||
continue
|
||||
if "file_name" not in doc.metadata:
|
||||
warnings.warn(
|
||||
"file_name not in metadata while page_label is in metadata: "
|
||||
f"{doc.metadata}"
|
||||
)
|
||||
table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
|
||||
|
||||
queries = [
|
||||
{"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
|
||||
for fn, pls in table_pages.items()
|
||||
]
|
||||
if queries:
|
||||
extra_docs = self.vector_retrieval(
|
||||
text="",
|
||||
top_k=50,
|
||||
where={"$or": queries},
|
||||
)
|
||||
for doc in extra_docs:
|
||||
if doc.doc_id not in retrieved_id:
|
||||
docs.append(doc)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
class PrepareEvidencePipeline(BaseComponent):
|
||||
"""Prepare the evidence text from the list of retrieved documents
|
||||
|
||||
|
@ -338,22 +223,22 @@ class FullQAPipeline(BaseComponent):
|
|||
allow_extra = True
|
||||
params_publish = True
|
||||
|
||||
retrieval_pipeline: DocumentRetrievalPipeline = DocumentRetrievalPipeline.withx()
|
||||
retrievers: list[BaseRetriever]
|
||||
evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
|
||||
answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()
|
||||
|
||||
async def run( # type: ignore
|
||||
self, question: str, history: list, **kwargs # type: ignore
|
||||
self, message: str, cid: str, history: list, **kwargs # type: ignore
|
||||
) -> Document: # type: ignore
|
||||
docs = self.retrieval_pipeline(text=question)
|
||||
docs = []
|
||||
for retriever in self.retrievers:
|
||||
docs.extend(retriever(text=message))
|
||||
evidence_mode, evidence = self.evidence_pipeline(docs).content
|
||||
answer = await self.answering_pipeline(
|
||||
question=question, evidence=evidence, evidence_mode=evidence_mode
|
||||
question=message, evidence=evidence, evidence_mode=evidence_mode
|
||||
)
|
||||
|
||||
# prepare citation
|
||||
from collections import defaultdict
|
||||
|
||||
spans = defaultdict(list)
|
||||
for fact_with_evidence in answer.metadata["citation"].answer:
|
||||
for quote in fact_with_evidence.substring_quote:
|
||||
|
@ -369,6 +254,7 @@ class FullQAPipeline(BaseComponent):
|
|||
break
|
||||
|
||||
id2docs = {doc.doc_id: doc for doc in docs}
|
||||
lack_evidence = True
|
||||
for id, ss in spans.items():
|
||||
if not ss:
|
||||
continue
|
||||
|
@ -391,31 +277,24 @@ class FullQAPipeline(BaseComponent):
|
|||
)
|
||||
}
|
||||
)
|
||||
lack_evidence = False
|
||||
|
||||
if lack_evidence:
|
||||
self.report_output({"evidence": "No evidence found"})
|
||||
|
||||
self.report_output(None)
|
||||
return answer
|
||||
|
||||
@classmethod
|
||||
def get_pipeline(cls, settings, **kwargs):
|
||||
def get_pipeline(cls, settings, retrievers, **kwargs):
|
||||
"""Get the reasoning pipeline
|
||||
|
||||
Need a base pipeline implementation. Currently the drawback is that we want to
|
||||
treat the retrievers as tools. Hence, the reasoning pipelie should just take
|
||||
the already initiated tools (retrievers), and do not need to set such logic
|
||||
here.
|
||||
Args:
|
||||
settings: the settings for the pipeline
|
||||
retrievers: the retrievers to use
|
||||
"""
|
||||
pipeline = FullQAPipeline(get_extra_table=settings["index.prioritize_table"])
|
||||
if not settings["index.use_reranking"]:
|
||||
pipeline.retrieval_pipeline.reranker = None # type: ignore
|
||||
|
||||
pipeline = FullQAPipeline(retrievers=retrievers)
|
||||
pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
|
||||
kwargs = {
|
||||
".retrieval_pipeline.top_k": int(settings["index.num_retrieval"]),
|
||||
".retrieval_pipeline.mmr": settings["index.mmr"],
|
||||
".retrieval_pipeline.doc_ids": kwargs.get("files", None),
|
||||
}
|
||||
pipeline.set_run(kwargs, temp=True)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
|
|
Loading…
Reference in New Issue
Block a user