Move LLM-related components into LLM module (#74)

* Move splitter into indexing module
* Rename post_processing module to parsers
* Migrate LLM-specific composite pipelines into llms module

This change moves the `splitters` module into the `indexing` module. The `indexing` module will be created soon to house indexing-related components.

This change renames the `post_processing` module to `parsers`. Post-processing is a generic term that conveys very little information. In the future, we will add other extractors to the `parsers` module, such as a metadata extractor.

This change migrates the LLM-specific composite pipelines into the `llms` module. These pipelines heavily assume that their internal nodes are LLM-specific, so housing them in the `llms` module makes them more discoverable and simplifies the codebase structure. The import-path impact is sketched below.
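For downstream code, the visible effect of these moves is the import path. Below is a rough before/after sketch; the full package paths are not shown in this diff, so the module paths used here are assumptions for illustration only.

# Assumed import paths, for illustration only (not taken from this diff)
# before: from kotaemon.post_processing import RegexExtractor
# before: from kotaemon.splitters import SimpleNodeParser
from kotaemon.parsers import FirstMatchRegexExtractor, RegexExtractor  # assumed new path
from kotaemon.indexing.splitters import SimpleNodeParser  # assumed new path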
Nguyen Trung Duc (john)
2023-11-15 16:26:53 +07:00
committed by GitHub
parent 9945afdf6f
commit f8b8d86d4e
13 changed files with 41 additions and 35 deletions


@@ -0,0 +1,3 @@
from .regex_extractor import FirstMatchRegexExtractor, RegexExtractor

__all__ = ["RegexExtractor", "FirstMatchRegexExtractor"]


@@ -0,0 +1,150 @@
from __future__ import annotations

import re
from typing import Callable

from theflow import Param

from kotaemon.base import BaseComponent, Document
from kotaemon.base.schema import ExtractorOutput


class RegexExtractor(BaseComponent):
    """
    Simple class for extracting text from a document using a regex pattern.

    Args:
        pattern (List[str]): The regex pattern(s) to use.
        output_map (dict, optional): A mapping from extracted text to the
            desired output. Defaults to None.
    """

    class Config:
        middleware_switches = {"theflow.middleware.CachingMiddleware": False}

    pattern: list[str]
    output_map: dict[str, str] | Callable[[str], str] = Param(
        default_callback=lambda *_: {}
    )

    def __init__(self, pattern: str | list[str], **kwargs):
        if isinstance(pattern, str):
            pattern = [pattern]
        super().__init__(pattern=pattern, **kwargs)

    @staticmethod
    def run_raw_static(pattern: str, text: str) -> list[str]:
        """
        Finds all non-overlapping occurrences of a pattern in a string.

        Parameters:
            pattern (str): The regular expression pattern to search for.
            text (str): The input string to search in.

        Returns:
            List[str]: A list of all non-overlapping occurrences of the pattern
                in the string.
        """
        return re.findall(pattern, text)

    @staticmethod
    def map_output(text, output_map) -> str:
        """
        Maps the given `text` to its corresponding value in the `output_map` dictionary.

        Parameters:
            text (str): The input text to be mapped.
            output_map (dict): A dictionary containing the mapping of input text
                to output values.

        Returns:
            str: The corresponding value from the `output_map` if `text` is found
                in the dictionary, otherwise returns the original `text`.
        """
        if not output_map:
            return text

        if isinstance(output_map, dict):
            return output_map.get(text, text)

        return output_map(text)

    def run_raw(self, text: str) -> ExtractorOutput:
        """
        Matches the raw text against the pattern and runs the output mapping,
        returning an instance of ExtractorOutput.

        Args:
            text (str): The raw text to be processed.

        Returns:
            ExtractorOutput: The processed output, containing the first match,
                all matches, and metadata.
        """
        output: list[str] = sum(
            [self.run_raw_static(p, text) for p in self.pattern], []
        )
        output = [self.map_output(text, self.output_map) for text in output]

        return ExtractorOutput(
            text=output[0] if output else "",
            matches=output,
            metadata={"origin": "RegexExtractor"},
        )

    def run(
        self, text: str | list[str] | Document | list[Document]
    ) -> list[ExtractorOutput]:
        """Match the input against a pattern and return the output for each input

        Parameters:
            text: contains the input string to be processed

        Returns:
            A list containing the output ExtractorOutput for each input

        Example:
            document1 = Document(...)
            document2 = Document(...)
            document_batch = [document1, document2]
            batch_output = self(document_batch)
            # batch_output will be [output1_document1, output1_document2]
        """
        # TODO: this conversion seems common
        input_: list[str] = []
        if not isinstance(text, list):
            text = [text]

        for item in text:
            if isinstance(item, str):
                input_.append(item)
            elif isinstance(item, Document):
                input_.append(item.text)
            else:
                raise ValueError(
                    f"Invalid input type {type(item)}, should be str or Document"
                )

        output = []
        for each_input in input_:
            output.append(self.run_raw(each_input))

        return output


class FirstMatchRegexExtractor(RegexExtractor):
    pattern: list[str]

    def run_raw(self, text: str) -> ExtractorOutput:
        for p in self.pattern:
            output = self.run_raw_static(p, text)
            if output:
                output = [self.map_output(text, self.output_map) for text in output]
                return ExtractorOutput(
                    text=output[0],
                    matches=output,
                    metadata={"origin": "FirstMatchRegexExtractor"},
                )

        return ExtractorOutput(
            text=None, matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
        )


@@ -1,70 +0,0 @@
from typing import Any, List, Sequence, Type

from llama_index.node_parser import (
    SentenceWindowNodeParser as LISentenceWindowNodeParser,
)
from llama_index.node_parser import SimpleNodeParser as LISimpleNodeParser
from llama_index.node_parser.interface import NodeParser
from llama_index.text_splitter import TokenTextSplitter

from ..base import BaseComponent, Document

__all__ = ["TokenTextSplitter"]


class LINodeParser(BaseComponent):
    _parser_class: Type[NodeParser]

    def __init__(self, *args, **kwargs):
        if self._parser_class is None:
            raise AttributeError(
                "Require `_parser_class` to be set to a NodeParser class from LlamaIndex"
            )
        self._parser = self._parser_class(*args, **kwargs)
        super().__init__()

    def __setattr__(self, name: str, value: Any) -> None:
        if name.startswith("_") or name in self._protected_keywords():
            return super().__setattr__(name, value)

        return setattr(self._parser, name, value)

    def __getattr__(self, name: str) -> Any:
        return getattr(self._parser, name)

    def get_nodes_from_documents(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
    ) -> List[Document]:
        documents = self._parser.get_nodes_from_documents(
            documents=documents, show_progress=show_progress
        )
        # convert Document to new base class from kotaemon
        converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
        return converted_documents

    def run(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
    ) -> List[Document]:
        return self.get_nodes_from_documents(
            documents=documents, show_progress=show_progress
        )


class SimpleNodeParser(LINodeParser):
    _parser_class = LISimpleNodeParser

    def __init__(self, *args, **kwargs):
        chunk_size = kwargs.pop("chunk_size", 512)
        chunk_overlap = kwargs.pop("chunk_overlap", 0)
        kwargs["text_splitter"] = TokenTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        super().__init__(*args, **kwargs)


class SentenceWindowNodeParser(LINodeParser):
    _parser_class = LISentenceWindowNodeParser
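For context, a rough sketch of how the wrapper removed here was typically driven; the values are illustrative only, and after this commit the same class is expected to live under the new `indexing` module rather than in this file.

# Illustrative only: chunk_size and chunk_overlap are forwarded into a TokenTextSplitter
from kotaemon.base import Document  # assumed import, mirroring the diff above

docs = [Document(text="A long piece of text to be chunked into smaller nodes ...")]
splitter = SimpleNodeParser(chunk_size=512, chunk_overlap=20)
nodes = splitter(docs)  # the component call delegates to get_nodes_from_documents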