Provide type hints for pass-through Langchain and Llama-index objects (#95)

This commit is contained in:
Duc Nguyen (john)
2023-12-04 10:59:13 +07:00
committed by GitHub
parent e34b1e4c6d
commit 0ce3a8832f
34 changed files with 641 additions and 310 deletions

View File

@@ -46,20 +46,48 @@ class LlamaIndexDocTransformerMixin:
"Please return the relevant LlamaIndex class in _get_li_class"
)
def __init__(self, *args, **kwargs):
    """Instantiate the llama-index class from ``_get_li_class`` and keep it in ``self._obj``."""
    wrapped_cls = self._get_li_class()
    self._obj = wrapped_cls(*args, **kwargs)
def __init__(self, **params):
    """Build the wrapped llama-index object and remember its constructor params.

    Args:
        **params: keyword arguments forwarded verbatim to the llama-index
            class returned by ``_get_li_class``; also cached in ``_kwargs``
            so they can be re-read and serialized later.
    """
    li_cls = self._get_li_class()
    wrapped = li_cls(**params)
    # Underscore-prefixed attributes are stored directly on this instance
    # (the overridden __setattr__ routes them to object.__setattr__).
    self._li_cls = li_cls
    self._obj = wrapped
    self._kwargs = params
    super().__init__()
def __repr__(self):
    """Return an eval-style representation: ``ClassName(key=repr(value), ...)``."""
    args = ", ".join(f"{key}={value!r}" for key, value in self._kwargs.items())
    return f"{self.__class__.__name__}({args})"
def __str__(self):
    """Return a human-readable summary, truncating long values to 15 chars + '...'."""
    parts = []
    for key, val in self._kwargs.items():
        text = str(val)
        if len(text) > 20:
            # Keep the display compact for long values.
            text = f"{text[:15]}..."
        parts.append(f"{key}={text}")
    return f"{self.__class__.__name__}({', '.join(parts)})"
def __setattr__(self, name: str, value: Any) -> None:
    """Mirror public attribute writes onto the wrapped object.

    Private (underscore-prefixed) and protected names are stored on this
    instance as usual; everything else is recorded in ``_kwargs`` and then
    forwarded to the wrapped llama-index object.
    """
    if name.startswith("_") or name in self._protected_keywords():
        super().__setattr__(name, value)
        return
    self._kwargs[name] = value
    setattr(self._obj, name, value)
def __getattr__(self, name: str) -> Any:
    """Resolve missing attributes from the cached params, then the wrapped object."""
    try:
        return self._kwargs[name]
    except KeyError:
        return getattr(self._obj, name)
def dump(self):
    """Serialize this component to a plain dict.

    Returns:
        A dict with a ``__type__`` key holding the fully-qualified class
        path, merged with the cached constructor keyword arguments.
    """
    type_path = f"{self.__module__}.{self.__class__.__qualname__}"
    return {"__type__": type_path, **self._kwargs}
def run(
self,
documents: list[Document],

View File

@@ -6,6 +6,14 @@ class BaseDocParser(DocTransformer):
class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
def __init__(self, llm=None, nodes: int = 5, **params):
    """Configure the wrapped title extractor.

    Args:
        llm: language model forwarded to the wrapped class (wrapped class
            default when None).
        nodes: forwarded to the wrapped class; defaults to 5.
        **params: extra keyword arguments forwarded as-is.
    """
    super().__init__(llm=llm, nodes=nodes, **params)
def _get_li_class(self):
from llama_index.extractors import TitleExtractor
@@ -13,6 +21,14 @@ class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
def __init__(
    self,
    llm=None,
    summaries: "list[str] | None" = None,
    **params,
):
    """Configure the wrapped summary extractor.

    Args:
        llm: language model forwarded to the wrapped class (wrapped class
            default when None).
        summaries: which summaries to compute; defaults to ``["self"]``.
        **params: extra keyword arguments forwarded as-is.
    """
    # Fix: the previous default ``summaries: list[str] = ["self"]`` was a
    # mutable default argument — the one shared list was stored into every
    # instance's ``_kwargs`` and exposed via attribute access, so mutating
    # it on one instance would silently change all others. Use a None
    # sentinel and build a fresh list per call instead.
    if summaries is None:
        summaries = ["self"]
    super().__init__(llm=llm, summaries=summaries, **params)
def _get_li_class(self):
from llama_index.extractors import SummaryExtractor

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import abstractmethod
from ...base import BaseComponent, Document
from kotaemon.base import BaseComponent, Document
class BaseReranking(BaseComponent):

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import os
from ...base import Document
from kotaemon.base import Document
from .base import BaseReranking

View File

@@ -1,17 +1,13 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from typing import Union
from langchain.output_parsers.boolean import BooleanOutputParser
from ...base import Document
from ...llms import PromptTemplate
from ...llms.chats.base import ChatLLM
from ...llms.completions.base import LLM
from .base import BaseReranking
from kotaemon.base import Document
from kotaemon.llms import BaseLLM, PromptTemplate
BaseLLM = Union[ChatLLM, LLM]
from .base import BaseReranking
RERANK_PROMPT_TEMPLATE = """Given the following question and context,
return YES if the context is relevant to the question and NO if it isn't.

View File

@@ -8,6 +8,20 @@ class BaseSplitter(DocTransformer):
class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(
    self,
    chunk_size: int = 1024,
    chunk_overlap: int = 20,
    separator: str = " ",
    **params,
):
    """Configure the wrapped token splitter.

    Args:
        chunk_size: forwarded to the wrapped class; defaults to 1024.
        chunk_overlap: forwarded to the wrapped class; defaults to 20.
        separator: forwarded to the wrapped class; defaults to a space.
        **params: extra keyword arguments forwarded as-is.
    """
    splitter_config = dict(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separator=separator,
    )
    splitter_config.update(params)
    super().__init__(**splitter_config)
def _get_li_class(self):
from llama_index.text_splitter import TokenTextSplitter
@@ -15,6 +29,9 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(self, window_size: int = 3, **params):
    """Forward window_size (default 3) plus any extra params to the wrapped class."""
    # window_size cannot also appear in **params (it binds to the named
    # parameter), so inserting it here is equivalent to the keyword form.
    params["window_size"] = window_size
    super().__init__(**params)
def _get_li_class(self):
from llama_index.node_parser import SentenceWindowNodeParser