Separate rerankers, splitters and extractors (#85)

This commit is contained in:
Nguyen Trung Duc (john)
2023-11-27 14:25:54 +07:00
committed by GitHub
parent 0dede9c82d
commit 2186c5558f
15 changed files with 211 additions and 135 deletions

View File

View File

@@ -0,0 +1,72 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Sequence, Type
from llama_index.node_parser.interface import NodeParser
from ..base import BaseComponent, Document
class DocTransformer(BaseComponent):
"""This is a base class for document transformers
A document transformer transforms a list of documents into another list
of documents. Transforming can mean splitting a document into multiple documents,
reducing a large list of documents into a smaller list of documents, or adding
metadata to each document in a list of documents, etc.
"""
@abstractmethod
def run(
self,
documents: Sequence[Document],
**kwargs,
) -> Sequence[Document]:
...
class LlamaIndexMixin:
"""Allow automatically wrapping a Llama-index component into kotaemon component
Example:
class TokenSplitter(LlamaIndexMixin, BaseSplitter):
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
return TokenTextSplitter
To use this mixin, please:
1. Use this class as the 1st parent class, so that Python will prefer to use
the attributes and methods of this class whenever possible.
2. Overwrite `_get_li_class` to return the relevant LlamaIndex component.
"""
def _get_li_class(self) -> Type[NodeParser]:
raise NotImplementedError(
"Please return the relevant LlamaIndex class in _get_li_class"
)
def __init__(self, *args, **kwargs):
_li_cls = self._get_li_class()
self._obj = _li_cls(*args, **kwargs)
super().__init__()
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_") or name in self._protected_keywords():
return super().__setattr__(name, value)
return setattr(self._obj, name, value)
def __getattr__(self, name: str) -> Any:
return getattr(self._obj, name)
def run(
self,
documents: Sequence[Document],
**kwargs,
) -> Sequence[Document]:
"""Run Llama-index node parser and convert the output to Document from
kotaemon
"""
docs = self._obj(documents, **kwargs) # type: ignore
return [Document.from_dict(doc.to_dict()) for doc in docs]

View File

@@ -0,0 +1,7 @@
from .doc_parsers import BaseDocParser, SummaryExtractor, TitleExtractor
__all__ = [
"BaseDocParser",
"TitleExtractor",
"SummaryExtractor",
]

View File

@@ -0,0 +1,19 @@
from ..base import DocTransformer, LlamaIndexMixin
class BaseDocParser(DocTransformer):
...
class TitleExtractor(LlamaIndexMixin, BaseDocParser):
def _get_li_class(self):
from llama_index.extractors import TitleExtractor
return TitleExtractor
class SummaryExtractor(LlamaIndexMixin, BaseDocParser):
def _get_li_class(self):
from llama_index.extractors import SummaryExtractor
return SummaryExtractor

View File

@@ -0,0 +1,5 @@
from .base import BaseReranking
from .cohere import CohereReranking
from .llm import LLMReranking
__all__ = ["CohereReranking", "LLMReranking", "BaseReranking"]

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from abc import abstractmethod
from ...base import BaseComponent, Document
class BaseReranking(BaseComponent):
@abstractmethod
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Main method to transform list of documents
(re-ranking, filtering, etc)"""
...

View File

@@ -0,0 +1,38 @@
from __future__ import annotations
import os
from ...base import Document
from .base import BaseReranking
class CohereReranking(BaseReranking):
model_name: str = "rerank-multilingual-v2.0"
cohere_api_key: str = os.environ.get("COHERE_API_KEY", "")
top_k: int = 1
def run(self, documents: list[Document], query: str) -> list[Document]:
"""Use Cohere Reranker model to re-order documents
with their relevance score"""
try:
import cohere
except ImportError:
raise ImportError(
"Please install Cohere " "`pip install cohere` to use Cohere Reranking"
)
cohere_client = cohere.Client(self.cohere_api_key)
# output documents
compressed_docs = []
if len(documents) > 0: # to avoid empty api call
_docs = [d.content for d in documents]
results = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
)
for r in results:
doc = documents[r.index]
doc.metadata["relevance_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from typing import Union
from langchain.output_parsers.boolean import BooleanOutputParser
from ...base import Document
from ...llms import PromptTemplate
from ...llms.chats.base import ChatLLM
from ...llms.completions.base import LLM
from .base import BaseReranking
BaseLLM = Union[ChatLLM, LLM]
RERANK_PROMPT_TEMPLATE = """Given the following question and context,
return YES if the context is relevant to the question and NO if it isn't.
> Question: {question}
> Context:
>>>
{context}
>>>
> Relevant (YES / NO):"""
class LLMReranking(BaseReranking):
llm: BaseLLM
prompt_template: PromptTemplate = PromptTemplate(template=RERANK_PROMPT_TEMPLATE)
top_k: int = 3
concurrent: bool = True
def run(
self,
documents: list[Document],
query: str,
) -> list[Document]:
"""Filter down documents based on their relevance to the query."""
filtered_docs = []
output_parser = BooleanOutputParser()
if self.concurrent:
with ThreadPoolExecutor() as executor:
futures = []
for doc in documents:
_prompt = self.prompt_template.populate(
question=query, context=doc.get_content()
)
futures.append(executor.submit(lambda: self.llm(_prompt).text))
results = [future.result() for future in futures]
else:
results = []
for doc in documents:
_prompt = self.prompt_template.populate(
question=query, context=doc.get_content()
)
results.append(self.llm(_prompt).text)
# use Boolean parser to extract relevancy output from LLM
results = [output_parser.parse(result) for result in results]
for include_doc, doc in zip(results, documents):
if include_doc:
filtered_docs.append(doc)
# prevent returning empty result
if len(filtered_docs) == 0:
filtered_docs = documents[: self.top_k]
return filtered_docs

View File

@@ -0,0 +1,21 @@
from ..base import DocTransformer, LlamaIndexMixin
class BaseSplitter(DocTransformer):
"""Represent base splitter class"""
...
class TokenSplitter(LlamaIndexMixin, BaseSplitter):
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
return TokenTextSplitter
class SentenceWindowSplitter(LlamaIndexMixin, BaseSplitter):
def _get_li_class(self):
from llama_index.node_parser import SentenceWindowNodeParser
return SentenceWindowNodeParser