kotaemon/knowledgehub/parsers/regex_extractor.py
Nguyen Trung Duc (john) f8b8d86d4e Move LLM-related components into LLM module (#74)
* Move splitter into indexing module
* Rename post_processing module to parsers
* Migrate LLM-specific composite pipelines into llms module

This change moves the `splitters` module into the `indexing` module. The `indexing` module will be created soon to house indexing-related components.

This change renames the `post_processing` module to `parsers`. "Post-processing" is a generic term that conveys very little information. In the future, we will add other extractors to the `parsers` module, such as a metadata extractor.

This change migrates the composite elements into the `llms` module. These elements heavily assume that their internal nodes are LLM-specific. As a result, migrating them into the `llms` module makes them more discoverable and simplifies the code base structure.
2023-11-15 16:26:53 +07:00

151 lines
4.7 KiB
Python

from __future__ import annotations
import re
from typing import Callable
from theflow import Param
from kotaemon.base import BaseComponent, Document
from kotaemon.base.schema import ExtractorOutput
class RegexExtractor(BaseComponent):
    """Extract text from a document using one or more regex patterns.

    Every pattern is applied to the input and the matches of all patterns are
    concatenated in pattern order. Each match is then optionally translated
    through ``output_map``.

    Args:
        pattern (list[str]): The regex pattern(s) to use. A single string is
            also accepted by the constructor and wrapped into a list.
        output_map (dict | Callable, optional): A mapping from extracted text
            to the desired output, given either as a dictionary or as a
            callable taking the extracted text. The default callback yields
            an empty dictionary, which leaves the matches unchanged.
    """

    class Config:
        # The CachingMiddleware switch is set to False for this component.
        middleware_switches = {"theflow.middleware.CachingMiddleware": False}

    pattern: list[str]
    output_map: dict[str, str] | Callable[[str], str] = Param(
        default_callback=lambda *_: {}
    )

    def __init__(self, pattern: str | list[str], **kwargs):
        # Normalise a lone pattern so the rest of the class can always iterate.
        if isinstance(pattern, str):
            pattern = [pattern]
        super().__init__(pattern=pattern, **kwargs)

    @staticmethod
    def run_raw_static(pattern: str, text: str) -> list[str]:
        """Find all non-overlapping occurrences of a pattern in a string.

        Parameters:
            pattern (str): The regular expression pattern to search for.
            text (str): The input string to search in.

        Returns:
            list[str]: All non-overlapping occurrences of the pattern in the
                string, as produced by ``re.findall``. Note that
                ``re.findall`` returns the group contents (tuples for
                several groups) when the pattern contains capture groups.
        """
        return re.findall(pattern, text)

    @staticmethod
    def map_output(text, output_map) -> str:
        """Map ``text`` to its corresponding value according to ``output_map``.

        Parameters:
            text (str): The input text to be mapped.
            output_map (dict | Callable): Either a dictionary mapping input
                text to output values, or a callable applied to the text.

        Returns:
            str: ``text`` unchanged when ``output_map`` is falsy; the
                dictionary value for ``text`` (or ``text`` itself when it is
                not a key) when ``output_map`` is a dict; otherwise the
                result of calling ``output_map(text)``.
        """
        if not output_map:
            return text
        if isinstance(output_map, dict):
            return output_map.get(text, text)
        return output_map(text)

    def run_raw(self, text: str) -> ExtractorOutput:
        """Match the raw text against every pattern and run the output mapping.

        Args:
            text (str): The raw text to be processed.

        Returns:
            ExtractorOutput: A single output whose ``matches`` holds every
                mapped match (in pattern order) and whose ``text`` is the
                first mapped match, or an empty string when nothing matched.
        """
        # Flatten with a comprehension: ``sum(lists, [])`` re-copies the
        # accumulated list on every addition and is quadratic in the total
        # number of matches.
        found = [
            match
            for each_pattern in self.pattern
            for match in self.run_raw_static(each_pattern, text)
        ]
        mapped = [self.map_output(match, self.output_map) for match in found]

        return ExtractorOutput(
            text=mapped[0] if mapped else "",
            matches=mapped,
            metadata={"origin": "RegexExtractor"},
        )

    def run(
        self, text: str | list[str] | Document | list[Document]
    ) -> list[ExtractorOutput]:
        """Match the input against the pattern(s) and return one output per input.

        Parameters:
            text: The input to be processed: a string, a Document, or a list
                of strings and/or Documents.

        Returns:
            A list containing one ExtractorOutput for each input.

        Raises:
            ValueError: If an input item is neither a ``str`` nor a
                ``Document``.

        Example:
            document1 = Document(...)
            document2 = Document(...)
            document_batch = [document1, document2]
            batch_output = self(document_batch)
            # batch_output will be [output1_document1, output1_document2]
        """
        # TODO: this conversion seems common
        items = text if isinstance(text, list) else [text]

        # Validate and convert every input before running any extraction, so
        # that an invalid item fails fast without partial work.
        raw_texts: list[str] = []
        for item in items:
            if isinstance(item, str):
                raw_texts.append(item)
            elif isinstance(item, Document):
                raw_texts.append(item.text)
            else:
                raise ValueError(
                    f"Invalid input type {type(item)}, should be str or Document"
                )

        return [self.run_raw(raw_text) for raw_text in raw_texts]
class FirstMatchRegexExtractor(RegexExtractor):
    """Regex extractor that stops at the first pattern yielding any match.

    Patterns are tried in order; the matches of the first pattern that finds
    something are mapped through ``output_map`` and returned. Later patterns
    are not evaluated.
    """

    pattern: list[str]

    def run_raw(self, text: str) -> ExtractorOutput:
        """Return the mapped matches of the first pattern that matches ``text``.

        Args:
            text (str): The raw text to be processed.

        Returns:
            ExtractorOutput: ``matches`` holds the mapped matches of the first
                successful pattern and ``text`` is the first of them. When no
                pattern matches, ``text`` is ``None`` and ``matches`` is empty.
        """
        origin = {"origin": "FirstMatchRegexExtractor"}

        for candidate in self.pattern:
            found = self.run_raw_static(candidate, text)
            if not found:
                # Nothing for this pattern; fall through to the next one.
                continue
            mapped = [self.map_output(item, self.output_map) for item in found]
            return ExtractorOutput(text=mapped[0], matches=mapped, metadata=origin)

        return ExtractorOutput(text=None, matches=[], metadata=origin)