Separate rerankers, splitters and extractors (#85)

parent 0dede9c82d
commit 2186c5558f
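In short: reranking, splitting, and extraction components move out of `kotaemon.pipelines` and `kotaemon.indexing.doc_parsers` into dedicated `kotaemon.indices` subpackages (`extractors`, `rankings`, `splitters`); the old reranking module (`kotaemon.pipelines.reranking`, judging by the removed imports) is deleted, with `kotaemon.indices.rankings.BaseReranking` taking over the role of `BaseRerankingPipeline`; the indexing pipeline's `doc_store` becomes optional; and type annotations are modernized to built-in generics and `|` unions.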
@@ -2,7 +2,7 @@ from __future__ import annotations

 import uuid
 from pathlib import Path
-from typing import cast
+from typing import Optional, cast

 from ..base import BaseComponent, Document
 from ..embeddings import BaseEmbeddings
@@ -22,7 +22,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
     """

     vector_store: BaseVectorStore
-    doc_store: BaseDocumentStore
+    doc_store: Optional[BaseDocumentStore] = None
     embedding: BaseEmbeddings
     # TODO: refer to llama_index's storage as well
@@ -64,7 +64,8 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
         if isinstance(path, str):
             path = Path(path)
         self.vector_store.save(path / vectorstore_fname)
-        self.doc_store.save(path / docstore_fname)
+        if self.doc_store:
+            self.doc_store.save(path / docstore_fname)

     def load(
         self,
@@ -76,4 +77,5 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
         if isinstance(path, str):
             path = Path(path)
         self.vector_store.load(path / vectorstore_fname)
-        self.doc_store.load(path / docstore_fname)
+        if self.doc_store:
+            self.doc_store.load(path / docstore_fname)
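With `doc_store` now optional, the indexing pipeline can run vector-store-only, and `save`/`load` skip the document store when it is absent. A minimal sketch of that path, assuming the constructor fields shown above and the default file-name arguments implied by `vectorstore_fname`/`docstore_fname`:

```python
from pathlib import Path

# doc_store defaults to None after this change, so this is vector-store-only.
pipeline = IndexVectorStoreFromDocumentPipeline(
    vector_store=my_vector_store,  # placeholder: any BaseVectorStore
    embedding=my_embedding,        # placeholder: any BaseEmbeddings
)
pipeline.save(Path("./storage"))   # persists only the vector store;
pipeline.load(Path("./storage"))   # the `if self.doc_store:` guards skip the rest
```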
@@ -1,6 +1,8 @@
 from __future__ import annotations

 import os
 from pathlib import Path
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Optional, Sequence

 from llama_index.readers.base import BaseReader
 from theflow import Node
@@ -8,8 +10,9 @@ from theflow.utils.modules import ObjectInitDeclaration as _

 from kotaemon.base import BaseComponent
 from kotaemon.embeddings import AzureOpenAIEmbeddings
-from kotaemon.indexing.doc_parsers import LIDocParser as DocParser
-from kotaemon.indexing.doc_parsers import TokenSplitter
+from kotaemon.indices.extractors import BaseDocParser
+from kotaemon.indices.rankings import BaseReranking
+from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.loaders import (
     AutoReader,
     DirectoryReader,
@@ -19,7 +22,6 @@ from kotaemon.loaders import (
 )
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
-from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.storages import (
     BaseDocumentStore,
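The downstream effect of the two import hunks above is a new import surface; the lines below are taken directly from the added and removed lines of this diff:

```python
# Removed:
#   from kotaemon.indexing.doc_parsers import LIDocParser as DocParser
#   from kotaemon.indexing.doc_parsers import TokenSplitter
#   from kotaemon.pipelines.reranking import BaseRerankingPipeline
# Added:
from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.rankings import BaseReranking
from kotaemon.indices.splitters import TokenSplitter
```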
@@ -45,7 +47,7 @@ class ReaderIndexingPipeline(BaseComponent):
     chunk_overlap: int = 256
     vector_store: BaseVectorStore = _(InMemoryVectorStore)
     doc_store: BaseDocumentStore = _(InMemoryDocumentStore)
-    doc_parsers: List[DocParser] = []
+    doc_parsers: list[BaseDocParser] = []

     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
@@ -55,9 +57,9 @@ class ReaderIndexingPipeline(BaseComponent):
         chunk_size=16,
     )

-    def get_reader(self, input_files: List[Union[str, Path]]):
+    def get_reader(self, input_files: list[str | Path]):
         # document parsers
-        file_extractor: Dict[str, BaseReader] = {
+        file_extractor: dict[str, BaseReader | AutoReader] = {
             ".xlsx": PandasExcelReader(),
         }
         if self.reader_name == "normal":
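The annotation changes in this file are pure PEP 585/604 modernization, safe at runtime because the module starts with `from __future__ import annotations`: built-in `list`/`dict` replace `typing.List`/`typing.Dict`, and `X | Y` replaces `Union[X, Y]`. Both forms accept the same values:

```python
from pathlib import Path

# Before: def get_reader(self, input_files: List[Union[str, Path]]): ...
# After:  def get_reader(self, input_files: list[str | Path]): ...
reader = pipeline.get_reader(["report.pdf", Path("notes.xlsx")])  # hypothetical files
```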
@@ -89,7 +91,7 @@ class ReaderIndexingPipeline(BaseComponent):

     def run(
         self,
-        file_path_list: Union[List[Union[str, Path]], Union[str, Path]],
+        file_path_list: list[str | Path] | str | Path,
         force_reindex: Optional[bool] = False,
     ):
         self.storage_path.mkdir(exist_ok=True)
@@ -121,9 +123,7 @@ class ReaderIndexingPipeline(BaseComponent):
         else:
             self.indexing_vector_pipeline.load(file_storage_path)

-    def to_retrieving_pipeline(
-        self, top_k=3, rerankers: Sequence[BaseRerankingPipeline] = []
-    ):
+    def to_retrieving_pipeline(self, top_k=3, rerankers: Sequence[BaseReranking] = []):
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
             doc_store=self.doc_store,
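A usage sketch of the collapsed `to_retrieving_pipeline` signature; the input file is hypothetical and the component call convention is assumed:

```python
indexing = ReaderIndexingPipeline()
indexing.run("manual.pdf")  # hypothetical input; accepts a path or list of paths
retriever = indexing.to_retrieving_pipeline(
    top_k=3,
    rerankers=[],  # now Sequence[BaseReranking], not Sequence[BaseRerankingPipeline]
)
```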
@@ -141,7 +141,7 @@ class ReaderIndexingPipeline(BaseComponent):
             doc_store=self.doc_store,
             embedding=self.embedding,
             llm=llm,
-            **kwargs
+            **kwargs,
         )
         return qa_pipeline

@@ -153,7 +153,7 @@ class ReaderIndexingPipeline(BaseComponent):
             doc_store=self.doc_store,
             embedding=self.embedding,
             agent=agent,
-            **kwargs
+            **kwargs,
         )
         agent_pipeline.add_search_tool()
         return agent_pipeline

@@ -8,11 +8,11 @@ from theflow.utils.modules import ObjectInitDeclaration as _
 from kotaemon.base import BaseComponent
 from kotaemon.base.schema import Document, RetrievedDocument
 from kotaemon.embeddings import AzureOpenAIEmbeddings
+from kotaemon.indices.rankings import BaseReranking
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.openai import AzureChatOpenAI
 from kotaemon.pipelines.agents import BaseAgent
 from kotaemon.pipelines.citation import CitationPipeline
-from kotaemon.pipelines.reranking import BaseRerankingPipeline
 from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
 from kotaemon.pipelines.tools import ComponentTool
 from kotaemon.storages import (
@@ -51,7 +51,7 @@ class QuestionAnsweringPipeline(BaseComponent):

     vector_store: BaseVectorStore = _(InMemoryVectorStore)
     doc_store: BaseDocumentStore = _(InMemoryDocumentStore)
-    rerankers: Sequence[BaseRerankingPipeline] = []
+    rerankers: Sequence[BaseReranking] = []

     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
@@ -1,114 +0,0 @@
-import os
-from abc import abstractmethod
-from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, Union
-
-from langchain.output_parsers.boolean import BooleanOutputParser
-
-from ..base import BaseComponent
-from ..base.schema import Document
-from ..llms import PromptTemplate
-from ..llms.chats.base import ChatLLM
-from ..llms.completions.base import LLM
-
-BaseLLM = Union[ChatLLM, LLM]
-
-
-class BaseRerankingPipeline(BaseComponent):
-    @abstractmethod
-    def run(self, documents: List[Document], query: str) -> List[Document]:
-        """Main method to transform list of documents
-        (re-ranking, filtering, etc)"""
-        ...
-
-
-class CohereReranking(BaseRerankingPipeline):
-    model_name: str = "rerank-multilingual-v2.0"
-    cohere_api_key: Optional[str] = None
-    top_k: int = 1
-
-    def run(self, documents: List[Document], query: str) -> List[Document]:
-        """Use Cohere Reranker model to re-order documents
-        with their relevance score"""
-        try:
-            import cohere
-        except ImportError:
-            raise ImportError(
-                "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
-            )
-
-        cohere_api_key = (
-            self.cohere_api_key if self.cohere_api_key else os.environ["COHERE_API_KEY"]
-        )
-        cohere_client = cohere.Client(cohere_api_key)
-
-        # output documents
-        compressed_docs = []
-        if len(documents) > 0:  # to avoid empty api call
-            _docs = [d.content for d in documents]
-            results = cohere_client.rerank(
-                model=self.model_name, query=query, documents=_docs, top_n=self.top_k
-            )
-            for r in results:
-                doc = documents[r.index]
-                doc.metadata["relevance_score"] = r.relevance_score
-                compressed_docs.append(doc)
-
-        return compressed_docs
-
-
-RERANK_PROMPT_TEMPLATE = """Given the following question and context,
-return YES if the context is relevant to the question and NO if it isn't.
-
-> Question: {question}
-> Context:
->>>
-{context}
->>>
-> Relevant (YES / NO):"""
-
-
-class LLMReranking(BaseRerankingPipeline):
-    llm: BaseLLM
-    prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
-    top_k: int = 3
-    concurrent: bool = True
-
-    def run(
-        self,
-        documents: List[Document],
-        query: str,
-    ) -> List[Document]:
-        """Filter down documents based on their relevance to the query."""
-        filtered_docs = []
-        output_parser = BooleanOutputParser()
-
-        if self.concurrent:
-            with ThreadPoolExecutor() as executor:
-                futures = []
-                for doc in documents:
-                    _prompt = self.prompt_template.populate(
-                        question=query, context=doc.get_content()
-                    )
-                    futures.append(executor.submit(lambda: self.llm(_prompt).text))
-
-                results = [future.result() for future in futures]
-        else:
-            results = []
-            for doc in documents:
-                _prompt = self.prompt_template.populate(
-                    question=query, context=doc.get_content()
-                )
-                results.append(self.llm(_prompt).text)
-
-        # use Boolean parser to extract relevancy output from LLM
-        results = [output_parser.parse(result) for result in results]
-        for include_doc, doc in zip(results, documents):
-            if include_doc:
-                filtered_docs.append(doc)
-
-        # prevent returning empty result
-        if len(filtered_docs) == 0:
-            filtered_docs = documents[: self.top_k]
-
-        return filtered_docs
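One detail worth flagging in the deleted `LLMReranking.run`: the concurrent branch submits `lambda: self.llm(_prompt).text`, and the lambda captures `_prompt` by reference, so a worker that starts after the loop has rebound `_prompt` evaluates a later document's prompt. If this class is re-homed under `kotaemon.indices.rankings` rather than retired, a one-line fix is to bind the value at submit time:

```python
# Default argument binds the current prompt per task, avoiding late binding:
futures.append(executor.submit(lambda p=_prompt: self.llm(p).text))
```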
@@ -3,11 +3,12 @@ from __future__ import annotations
 from pathlib import Path
 from typing import Optional, Sequence

+from kotaemon.indices.rankings import BaseReranking
+
 from ..base import BaseComponent
 from ..base.schema import Document, RetrievedDocument
 from ..embeddings import BaseEmbeddings
 from ..storages import BaseDocumentStore, BaseVectorStore
-from .reranking import BaseRerankingPipeline

 VECTOR_STORE_FNAME = "vectorstore"
 DOC_STORE_FNAME = "docstore"
@@ -19,7 +20,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
     vector_store: BaseVectorStore
     doc_store: BaseDocumentStore
     embedding: BaseEmbeddings
-    rerankers: Sequence[BaseRerankingPipeline] = []
+    rerankers: Sequence[BaseReranking] = []
    top_k: int = 1
     # TODO: refer to llama_index's storage as well
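For callers, the only change to the retrieval pipeline is the reranker field's type. A construction sketch, with placeholder stores and a reranker assumed to implement the new `BaseReranking` interface:

```python
from kotaemon.indices.rankings import BaseReranking

retriever = RetrieveDocumentFromVectorStorePipeline(
    vector_store=my_vector_store,  # placeholder BaseVectorStore
    doc_store=my_doc_store,        # placeholder BaseDocumentStore
    embedding=my_embedding,        # placeholder BaseEmbeddings
    rerankers=my_rerankers,        # Sequence[BaseReranking], e.g. a reranker
                                   # ported from the deleted module
    top_k=1,
)
```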