Relate the retrievers to the indexer
parent 9b586466ff
commit c6637ca56e
@@ -1,7 +1,13 @@
 from kotaemon.base import BaseComponent


-class BaseIndex(BaseComponent):
+class BaseRetriever(BaseComponent):
+    pass
+
+
+class BaseIndexing(BaseComponent):
     """The pipeline to index information into the data store"""

     def get_user_settings(self) -> dict:
         """Get the user settings for indexing
@@ -12,5 +18,8 @@ class BaseIndex(BaseComponent):
         return {}

     @classmethod
-    def get_pipeline(cls, setting: dict) -> "BaseIndex":
+    def get_pipeline(cls, settings: dict) -> "BaseIndexing":
+        raise NotImplementedError
+
+    def get_retrievers(self, settings: dict, **kwargs) -> list[BaseRetriever]:
+        raise NotImplementedError
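The first file of the commit (ktem/indexing/base.py, judging from the import statements in the later hunks) now separates the two roles: a `BaseRetriever` answers queries, while a `BaseIndexing` pipeline ingests content and reports which retrievers it exposes via `get_retrievers`. A minimal conforming implementation might look like the sketch below; the class names and bodies are hypothetical, not part of the commit.

```python
from ktem.indexing.base import BaseIndexing, BaseRetriever


class EchoRetriever(BaseRetriever):
    """Hypothetical retriever; a real one would query a vector or doc store."""

    def run(self, text: str) -> list:
        return []  # no-op, for illustration only


class EchoIndexing(BaseIndexing):
    """Hypothetical index exposing a single retriever."""

    @classmethod
    def get_pipeline(cls, settings: dict) -> "EchoIndexing":
        return cls()

    def get_retrievers(self, settings: dict, **kwargs) -> list[BaseRetriever]:
        # An index may expose several retrievers; the app queries all of them.
        return [EchoRetriever()]
```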
@@ -1,16 +1,35 @@
 from __future__ import annotations

 import shutil
+import warnings
+from collections import defaultdict
 from hashlib import sha256
 from pathlib import Path
+from typing import Optional

-from ktem.components import embeddings, filestorage_path, get_docstore, get_vectorstore
+from ktem.components import (
+    embeddings,
+    filestorage_path,
+    get_docstore,
+    get_vectorstore,
+    llms,
+)
 from ktem.db.models import Index, Source, SourceTargetRelation, engine
-from ktem.indexing.base import BaseIndex
+from ktem.indexing.base import BaseIndexing, BaseRetriever
 from ktem.indexing.exceptions import FileExistsError
-from kotaemon.indices import VectorIndexing
+from kotaemon.base import RetrievedDocument
+from kotaemon.indices import VectorIndexing, VectorRetrieval
 from kotaemon.indices.ingests import DocumentIngestor
+from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
+from llama_index.vector_stores import (
+    FilterCondition,
+    FilterOperator,
+    MetadataFilter,
+    MetadataFilters,
+)
+from llama_index.vector_stores.types import VectorStoreQueryMode
 from sqlmodel import Session, select
+from theflow.settings import settings

 USER_SETTINGS = {
     "index_parser": {
@@ -61,7 +80,109 @@ USER_SETTINGS = {
 }


-class IndexDocumentPipeline(BaseIndex):
+class DocumentRetrievalPipeline(BaseRetriever):
+    """Retrieve relevant documents
+
+    Args:
+        vector_retrieval: the retrieval pipeline that returns the relevant
+            documents given a text query
+        reranker: the reranking pipeline that re-ranks and filters the retrieved
+            documents
+        get_extra_table: if True, for each retrieved document, the pipeline will
+            look for surrounding tables (e.g. within the page)
+    """
+
+    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
+        doc_store=get_docstore(),
+        vector_store=get_vectorstore(),
+        embedding=embeddings.get_default(),
+    )
+    reranker: BaseReranking = CohereReranking.withx(
+        cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
+    ) >> LLMReranking.withx(llm=llms.get_lowest_cost())
+    get_extra_table: bool = False
+
+    def run(
+        self,
+        text: str,
+        top_k: int = 5,
+        mmr: bool = False,
+        doc_ids: Optional[list[str]] = None,
+    ) -> list[RetrievedDocument]:
+        """Retrieve document excerpts similar to the text
+
+        Args:
+            text: the text to retrieve similar documents for
+            top_k: number of documents to retrieve
+            mmr: whether to use MMR to re-rank the documents
+            doc_ids: list of document ids to constrain the retrieval
+        """
+        kwargs = {}
+        if doc_ids:
+            with Session(engine) as session:
+                stmt = select(Index).where(
+                    Index.relation_type == SourceTargetRelation.VECTOR,
+                    Index.source_id.in_(doc_ids),  # type: ignore
+                )
+                results = session.exec(stmt)
+                vs_ids = [r.target_id for r in results.all()]
+
+            kwargs["filters"] = MetadataFilters(
+                filters=[
+                    MetadataFilter(
+                        key="doc_id",
+                        value=vs_id,
+                        operator=FilterOperator.EQ,
+                    )
+                    for vs_id in vs_ids
+                ],
+                condition=FilterCondition.OR,
+            )
+
+        if mmr:
+            # TODO: double check that llama-index MMR works correctly
+            kwargs["mode"] = VectorStoreQueryMode.MMR
+            kwargs["mmr_threshold"] = 0.5
+
+        # rerank
+        docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
+        if self.get_from_path("reranker"):
+            docs = self.reranker(docs, query=text)
+
+        if not self.get_extra_table:
+            return docs
+
+        # retrieve extra nodes related to tables
+        table_pages = defaultdict(list)
+        retrieved_id = set([doc.doc_id for doc in docs])
+        for doc in docs:
+            if "page_label" not in doc.metadata:
+                continue
+            if "file_name" not in doc.metadata:
+                warnings.warn(
+                    "file_name not in metadata while page_label is in metadata: "
+                    f"{doc.metadata}"
+                )
+            table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
+
+        queries = [
+            {"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
+            for fn, pls in table_pages.items()
+        ]
+        if queries:
+            extra_docs = self.vector_retrieval(
+                text="",
+                top_k=50,
+                where={"$or": queries},
+            )
+            for doc in extra_docs:
+                if doc.doc_id not in retrieved_id:
+                    docs.append(doc)
+
+        return docs
+
+
+class IndexDocumentPipeline(BaseIndexing):
     """Store the documents and index the content into vector store and doc store

     Args:
@@ -175,8 +296,29 @@ class IndexDocumentPipeline(BaseIndex):
         return USER_SETTINGS

     @classmethod
-    def get_pipeline(cls, setting) -> "IndexDocumentPipeline":
+    def get_pipeline(cls, settings) -> "IndexDocumentPipeline":
         """Get the pipeline based on the setting"""
         obj = cls()
-        obj.file_ingestor.pdf_mode = setting["index.index_parser"]
+        obj.file_ingestor.pdf_mode = settings["index.index_parser"]
         return obj
+
+    def get_retrievers(self, settings, **kwargs) -> list[BaseRetriever]:
+        """Get retriever objects associated with the index
+
+        Args:
+            settings: the settings of the app
+            kwargs: other arguments
+        """
+        retriever = DocumentRetrievalPipeline(
+            get_extra_table=settings["index.prioritize_table"]
+        )
+        if not settings["index.use_reranking"]:
+            retriever.reranker = None  # type: ignore
+
+        kwargs = {
+            ".top_k": int(settings["index.num_retrieval"]),
+            ".mmr": settings["index.mmr"],
+            ".doc_ids": kwargs.get("files", None),
+        }
+        retriever.set_run(kwargs, temp=True)
+        return [retriever]
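Two details worth noting in the hunks above: `DocumentRetrievalPipeline` moves into the indexing module (it previously lived in ktem/reasoning/simple.py, removed in the last file of this commit), and `get_retrievers` pre-binds `.top_k`/`.mmr`/`.doc_ids` on the retriever with `set_run(..., temp=True)`, theflow's mechanism for temporary run arguments. Consumption then reduces to passing query text, as in this hedged sketch (the settings values and document id are invented):

```python
# Invented settings, using the keys this diff reads; values are examples only.
settings = {
    "index.index_parser": "normal",
    "index.prioritize_table": True,
    "index.use_reranking": False,
    "index.num_retrieval": 3,
    "index.mmr": False,
}

indexing = IndexDocumentPipeline.get_pipeline(settings)
retrievers = indexing.get_retrievers(settings, files=["doc-42"])  # made-up id

# Each retriever already carries top_k/mmr/doc_ids, so callers pass only text.
docs = [doc for r in retrievers for doc in r(text="total contract value?")]
```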
@@ -36,6 +36,7 @@ class ChatPage(BasePage):
         ).then(
             fn=chat_fn,
             inputs=[
+                self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self.data_source.files,
                 self._app.settings_state,
@@ -64,6 +65,7 @@ class ChatPage(BasePage):
         ).then(
             fn=chat_fn,
             inputs=[
+                self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self.data_source.files,
                 self._app.settings_state,
@@ -7,8 +7,7 @@ from typing import Optional
 import gradio as gr
 from ktem.components import llms, reasonings
 from ktem.db.models import Conversation, Source, engine
-from ktem.indexing.base import BaseIndex
-from ktem.reasoning.simple import DocumentRetrievalPipeline
+from ktem.indexing.base import BaseIndexing
 from sqlmodel import Session, select
 from theflow.settings import settings as app_settings
 from theflow.utils.modules import import_dotted_string
@@ -26,9 +25,15 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
         the pipeline objects
     """

+    # get retrievers
+    indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
+    retrievers = indexing_cls.get_pipeline(settings).get_retrievers(
+        settings, files=files
+    )
+
     reasoning_mode = settings["reasoning.use"]
     reasoning_cls = reasonings[reasoning_mode]
-    pipeline = reasoning_cls.get_pipeline(settings, files=files)
+    pipeline = reasoning_cls.get_pipeline(settings, retrievers, files=files)

     if settings["reasoning.use"] in ["rewoo", "react"]:
         from kotaemon.agents import ReactAgent, RewooAgent
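This hunk is the heart of the commit: the app resolves the configured indexing class from `KH_INDEX`, asks it for its retrievers, and injects them into whichever reasoning pipeline the user selected. Because the result is a plain `list[BaseRetriever]`, several indices could contribute retrievers to one reasoning run, roughly as in this sketch (the list of dotted paths is an assumption for illustration; the commit wires up exactly one):

```python
from theflow.utils.modules import import_dotted_string

index_paths = [app_settings.KH_INDEX]  # could grow to multiple indices
retrievers = []
for path in index_paths:
    indexing_cls = import_dotted_string(path, safe=False)
    retrievers += indexing_cls.get_pipeline(settings).get_retrievers(
        settings, files=files
    )

pipeline = reasoning_cls.get_pipeline(settings, retrievers, files=files)
```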
@@ -49,47 +54,38 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
         from kotaemon.agents import LLMTool

         tools.append(LLMTool(llm=llm))
-    elif tool == "docsearch":
-        from kotaemon.agents import ComponentTool
-
-        filenames = ""
-        if files:
-            with Session(engine) as session:
-                statement = select(Source).where(
-                    Source.id.in_(files)  # type: ignore
-                )
-                results = session.exec(statement).all()
-                filenames = (
-                    "The file names are: "
-                    + " ".join([result.name for result in results])
-                    + ". "
-                )
-
-        retrieval_pipeline = DocumentRetrievalPipeline()
-        retrieval_pipeline.set_run(
-            {
-                ".top_k": int(settings["retrieval_number"]),
-                ".mmr": settings["retrieval_mmr"],
-                ".doc_ids": files,
-            },
-            temp=True,
-        )
-        tool = ComponentTool(
-            name="docsearch",
-            description=(
-                "A vector store that searches for similar and "
-                "related content "
-                f"in a document. {filenames}"
-                "The result is a huge chunk of text related "
-                "to your search but can also "
-                "contain irrelevant info."
-            ),
-            component=retrieval_pipeline,
-            postprocessor=lambda docs: "\n\n".join(
-                [doc.text.replace("\n", " ") for doc in docs]
-            ),
-        )
-        tools.append(tool)
+    # elif tool == "docsearch":
+    #     pass
+
+    # filenames = ""
+    # if files:
+    #     with Session(engine) as session:
+    #         statement = select(Source).where(
+    #             Source.id.in_(files)  # type: ignore
+    #         )
+    #         results = session.exec(statement).all()
+    #         filenames = (
+    #             "The file names are: "
+    #             + " ".join([result.name for result in results])
+    #             + ". "
+    #         )
+
+    # tool = ComponentTool(
+    #     name="docsearch",
+    #     description=(
+    #         "A vector store that searches for similar and "
+    #         "related content "
+    #         f"in a document. {filenames}"
+    #         "The result is a huge chunk of text related "
+    #         "to your search but can also "
+    #         "contain irrelevant info."
+    #     ),
+    #     component=retrieval_pipeline,
+    #     postprocessor=lambda docs: "\n\n".join(
+    #         [doc.text.replace("\n", " ") for doc in docs]
+    #     ),
+    # )
+    # tools.append(tool)
     elif tool == "google":
         from kotaemon.agents import GoogleSearchTool
@@ -117,7 +113,7 @@ def create_pipeline(settings: dict, files: Optional[list] = None):
     return pipeline


-async def chat_fn(chat_history, files, settings):
+async def chat_fn(conversation_id, chat_history, files, settings):
     """Chat function"""
     chat_input = chat_history[-1][0]
     chat_history = chat_history[:-1]
@@ -128,7 +124,7 @@ async def chat_fn(chat_history, files, settings):
     pipeline = create_pipeline(settings, files)
     pipeline.set_output_queue(queue)

-    asyncio.create_task(pipeline(chat_input, chat_history))
+    asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
     text, refs = "", ""

     while True:
@@ -207,7 +203,7 @@ def index_fn(files, reindex: bool, selected_files, settings):
     gr.Info(f"Start indexing {len(files)} files...")

     # get the pipeline
-    indexing_cls: BaseIndex = import_dotted_string(app_settings.KH_INDEX, safe=False)
+    indexing_cls: BaseIndexing = import_dotted_string(app_settings.KH_INDEX, safe=False)
     indexing_pipeline = indexing_cls.get_pipeline(settings)

     output_nodes, file_ids = indexing_pipeline(files, reindex=reindex)
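With that, the conversation id flows from the `ChatPage` inputs (two files earlier) through `chat_fn` into the pipeline call. Condensed to the essential lines, omitting the queue and streaming logic (a sketch of the flow, not the full function):

```python
import asyncio

async def chat_fn(conversation_id, chat_history, files, settings):
    chat_input = chat_history[-1][0]
    pipeline = create_pipeline(settings, files)
    # run(message, cid, history): the id is the new second positional argument
    asyncio.create_task(pipeline(chat_input, conversation_id, chat_history[:-1]))
```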
@@ -1,13 +1,11 @@
 import asyncio
 import logging
-import warnings
 from collections import defaultdict
 from functools import partial
-from typing import Optional

 import tiktoken
-from ktem.components import embeddings, get_docstore, get_vectorstore, llms
-from ktem.db.models import Index, SourceTargetRelation, engine
+from ktem.components import llms
+from ktem.indexing.base import BaseRetriever
 from kotaemon.base import (
     BaseComponent,
     Document,
@@ -16,126 +14,13 @@ from kotaemon.base import (
     RetrievedDocument,
     SystemMessage,
 )
-from kotaemon.indices import VectorRetrieval
 from kotaemon.indices.qa.citation import CitationPipeline
-from kotaemon.indices.rankings import BaseReranking, CohereReranking, LLMReranking
 from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import ChatLLM, PromptTemplate
-from llama_index.vector_stores import (
-    FilterCondition,
-    FilterOperator,
-    MetadataFilter,
-    MetadataFilters,
-)
-from llama_index.vector_stores.types import VectorStoreQueryMode
-from sqlmodel import Session, select
-from theflow.settings import settings

 logger = logging.getLogger(__name__)


-class DocumentRetrievalPipeline(BaseComponent):
-    """Retrieve relevant document
-
-    Args:
-        vector_retrieval: the retrieval pipeline that return the relevant documents
-            given a text query
-        reranker: the reranking pipeline that re-rank and filter the retrieved
-            documents
-        get_extra_table: if True, for each retrieved document, the pipeline will look
-            for surrounding tables (e.g. within the page)
-    """
-
-    vector_retrieval: VectorRetrieval = VectorRetrieval.withx(
-        doc_store=get_docstore(),
-        vector_store=get_vectorstore(),
-        embedding=embeddings.get_default(),
-    )
-    reranker: BaseReranking = CohereReranking.withx(
-        cohere_api_key=getattr(settings, "COHERE_API_KEY", "")
-    ) >> LLMReranking.withx(llm=llms.get_lowest_cost())
-    get_extra_table: bool = False
-
-    def run(
-        self,
-        text: str,
-        top_k: int = 5,
-        mmr: bool = False,
-        doc_ids: Optional[list[str]] = None,
-    ) -> list[RetrievedDocument]:
-        """Retrieve document excerpts similar to the text
-
-        Args:
-            text: the text to retrieve similar documents
-            top_k: number of documents to retrieve
-            mmr: whether to use mmr to re-rank the documents
-            doc_ids: list of document ids to constraint the retrieval
-        """
-        kwargs = {}
-        if doc_ids:
-            with Session(engine) as session:
-                stmt = select(Index).where(
-                    Index.relation_type == SourceTargetRelation.VECTOR,
-                    Index.source_id.in_(doc_ids),  # type: ignore
-                )
-                results = session.exec(stmt)
-                vs_ids = [r.target_id for r in results.all()]
-
-            kwargs["filters"] = MetadataFilters(
-                filters=[
-                    MetadataFilter(
-                        key="doc_id",
-                        value=vs_id,
-                        operator=FilterOperator.EQ,
-                    )
-                    for vs_id in vs_ids
-                ],
-                condition=FilterCondition.OR,
-            )
-
-        if mmr:
-            # TODO: double check that llama-index MMR works correctly
-            kwargs["mode"] = VectorStoreQueryMode.MMR
-            kwargs["mmr_threshold"] = 0.5
-
-        # rerank
-        docs = self.vector_retrieval(text=text, top_k=top_k, **kwargs)
-        if self.get_from_path("reranker"):
-            docs = self.reranker(docs, query=text)
-
-        if not self.get_extra_table:
-            return docs
-
-        # retrieve extra nodes relate to table
-        table_pages = defaultdict(list)
-        retrieved_id = set([doc.doc_id for doc in docs])
-        for doc in docs:
-            if "page_label" not in doc.metadata:
-                continue
-            if "file_name" not in doc.metadata:
-                warnings.warn(
-                    "file_name not in metadata while page_label is in metadata: "
-                    f"{doc.metadata}"
-                )
-            table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
-
-        queries = [
-            {"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
-            for fn, pls in table_pages.items()
-        ]
-        if queries:
-            extra_docs = self.vector_retrieval(
-                text="",
-                top_k=50,
-                where={"$or": queries},
-            )
-            for doc in extra_docs:
-                if doc.doc_id not in retrieved_id:
-                    docs.append(doc)
-
-        return docs
-
-
 class PrepareEvidencePipeline(BaseComponent):
     """Prepare the evidence text from the list of retrieved documents

@@ -338,22 +223,22 @@ class FullQAPipeline(BaseComponent):
     allow_extra = True
     params_publish = True

-    retrieval_pipeline: DocumentRetrievalPipeline = DocumentRetrievalPipeline.withx()
+    retrievers: list[BaseRetriever]
     evidence_pipeline: PrepareEvidencePipeline = PrepareEvidencePipeline.withx()
     answering_pipeline: AnswerWithContextPipeline = AnswerWithContextPipeline.withx()

     async def run(  # type: ignore
-        self, question: str, history: list, **kwargs  # type: ignore
+        self, message: str, cid: str, history: list, **kwargs  # type: ignore
     ) -> Document:  # type: ignore
-        docs = self.retrieval_pipeline(text=question)
+        docs = []
+        for retriever in self.retrievers:
+            docs.extend(retriever(text=message))
         evidence_mode, evidence = self.evidence_pipeline(docs).content
         answer = await self.answering_pipeline(
-            question=question, evidence=evidence, evidence_mode=evidence_mode
+            question=message, evidence=evidence, evidence_mode=evidence_mode
         )

         # prepare citation
-        from collections import defaultdict
-
         spans = defaultdict(list)
         for fact_with_evidence in answer.metadata["citation"].answer:
             for quote in fact_with_evidence.substring_quote:
@@ -369,6 +254,7 @@ class FullQAPipeline(BaseComponent):
                     break

         id2docs = {doc.doc_id: doc for doc in docs}
+        lack_evidence = True
         for id, ss in spans.items():
             if not ss:
                 continue
@@ -391,31 +277,24 @@ class FullQAPipeline(BaseComponent):
                 )
             }
         )
+            lack_evidence = False
+
+        if lack_evidence:
+            self.report_output({"evidence": "No evidence found"})

         self.report_output(None)
         return answer

     @classmethod
-    def get_pipeline(cls, settings, **kwargs):
+    def get_pipeline(cls, settings, retrievers, **kwargs):
         """Get the reasoning pipeline

-        Need a base pipeline implementation. Currently the drawback is that we want to
-        treat the retrievers as tools. Hence, the reasoning pipelie should just take
-        the already initiated tools (retrievers), and do not need to set such logic
-        here.
+        Args:
+            settings: the settings for the pipeline
+            retrievers: the retrievers to use
         """
-        pipeline = FullQAPipeline(get_extra_table=settings["index.prioritize_table"])
-        if not settings["index.use_reranking"]:
-            pipeline.retrieval_pipeline.reranker = None  # type: ignore
-
+        pipeline = FullQAPipeline(retrievers=retrievers)
         pipeline.answering_pipeline.llm = llms.get_highest_accuracy()
-        kwargs = {
-            ".retrieval_pipeline.top_k": int(settings["index.num_retrieval"]),
-            ".retrieval_pipeline.mmr": settings["index.mmr"],
-            ".retrieval_pipeline.doc_ids": kwargs.get("files", None),
-        }
-        pipeline.set_run(kwargs, temp=True)
-
         return pipeline

     @classmethod
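End to end, then (this last file is ktem/reasoning/simple.py, per the import that the events module dropped): `FullQAPipeline` no longer owns a retrieval pipeline but fans the question out over whatever retrievers it was given; reranking and top_k tuning now live with the index. A hedged sketch of the full wiring, assuming `settings` and `files` as in the earlier examples and an invented conversation id:

```python
async def answer_once(settings: dict, files: list):
    # Build retrievers from the configured index (as create_pipeline now does)...
    retrievers = IndexDocumentPipeline.get_pipeline(settings).get_retrievers(
        settings, files=files
    )
    # ...and inject them; the pipeline no longer constructs its own retrieval.
    qa = FullQAPipeline.get_pipeline(settings, retrievers)
    # run() takes (message, cid, history); "conv-1" is a made-up id.
    return await qa.run("What are the payment terms?", "conv-1", [])
```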