From b7940516530011f3a5c44d83664b468b50f57517 Mon Sep 17 00:00:00 2001 From: ian_Cin Date: Tue, 19 Sep 2023 19:54:44 +0700 Subject: [PATCH] [AUR-421] base output post-processor that works using regex. (#20) --- .pre-commit-config.yaml | 1 + knowledgehub/{components.py => base.py} | 0 knowledgehub/documents/base.py | 2 +- knowledgehub/embeddings/base.py | 2 +- knowledgehub/llms/base.py | 2 +- knowledgehub/llms/chats/base.py | 2 +- knowledgehub/llms/completions/base.py | 2 +- knowledgehub/pipelines/indexing.py | 2 +- knowledgehub/pipelines/retrieving.py | 2 +- knowledgehub/post_processing/__init__.py | 0 knowledgehub/post_processing/extractor.py | 166 ++++++++++++++++++++++ tests/test_post_processing.py | 38 +++++ 12 files changed, 212 insertions(+), 7 deletions(-) rename knowledgehub/{components.py => base.py} (100%) create mode 100644 knowledgehub/post_processing/__init__.py create mode 100644 knowledgehub/post_processing/extractor.py create mode 100644 tests/test_post_processing.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d97508e..dd2dd90 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,3 +47,4 @@ repos: rev: "v1.5.1" hooks: - id: mypy + args: ["--check-untyped-defs", "--ignore-missing-imports"] diff --git a/knowledgehub/components.py b/knowledgehub/base.py similarity index 100% rename from knowledgehub/components.py rename to knowledgehub/base.py diff --git a/knowledgehub/documents/base.py b/knowledgehub/documents/base.py index d9c1981..29bbcec 100644 --- a/knowledgehub/documents/base.py +++ b/knowledgehub/documents/base.py @@ -15,7 +15,7 @@ class Document(BaseDocument): ) return document - def to_haystack_format(self) -> HaystackDocument: + def to_haystack_format(self) -> "HaystackDocument": """Convert struct to Haystack document format.""" metadata = self.metadata or {} text = self.text diff --git a/knowledgehub/embeddings/base.py b/knowledgehub/embeddings/base.py index 2c261ea..cc520dc 100644 --- 
a/knowledgehub/embeddings/base.py +++ b/knowledgehub/embeddings/base.py @@ -4,7 +4,7 @@ from typing import List, Type from langchain.schema.embeddings import Embeddings as LCEmbeddings from theflow import Param -from ..components import BaseComponent +from ..base import BaseComponent from ..documents.base import Document diff --git a/knowledgehub/llms/base.py b/knowledgehub/llms/base.py index db09bd9..4fefece 100644 --- a/knowledgehub/llms/base.py +++ b/knowledgehub/llms/base.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from typing import List -from ..components import BaseComponent +from ..base import BaseComponent @dataclass diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py index 53a5b9b..fd44aad 100644 --- a/knowledgehub/llms/chats/base.py +++ b/knowledgehub/llms/chats/base.py @@ -4,7 +4,7 @@ from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import BaseMessage, HumanMessage from theflow.base import Param -from ...components import BaseComponent +from ...base import BaseComponent from ..base import LLMInterface Message = TypeVar("Message", bound=BaseMessage) diff --git a/knowledgehub/llms/completions/base.py b/knowledgehub/llms/completions/base.py index 145979e..6cbba52 100644 --- a/knowledgehub/llms/completions/base.py +++ b/knowledgehub/llms/completions/base.py @@ -3,7 +3,7 @@ from typing import List, Type from langchain.schema.language_model import BaseLanguageModel from theflow.base import Param -from ...components import BaseComponent +from ...base import BaseComponent from ..base import LLMInterface diff --git a/knowledgehub/pipelines/indexing.py b/knowledgehub/pipelines/indexing.py index 86d7cb5..a3b0c72 100644 --- a/knowledgehub/pipelines/indexing.py +++ b/knowledgehub/pipelines/indexing.py @@ -2,7 +2,7 @@ from typing import List from theflow import Node, Param -from ..components import BaseComponent +from ..base import BaseComponent from ..documents.base import 
import re
from typing import Dict, List

from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document


class RegexExtractor(BaseComponent):
    """Extract text from a document using a regex pattern.

    Args:
        pattern (str): The regex pattern to use.
        output_map (dict, optional): A mapping from extracted text to the
            desired output. Defaults to an empty dict, in which case matches
            are returned unchanged.
    """

    pattern: str
    output_map: Dict[str, str] = {}

    @staticmethod
    def run_raw_static(pattern: str, text: str) -> List[str]:
        """Find all non-overlapping occurrences of a pattern in a string.

        Parameters:
            pattern (str): The regular expression pattern to search for.
            text (str): The input string to search in.

        Returns:
            List[str]: All non-overlapping occurrences of the pattern in the
                string, in order of appearance.
        """
        return re.findall(pattern, text)

    @staticmethod
    def map_output(text: str, output_map: Dict[str, str]) -> str:
        """Map ``text`` to its corresponding value in ``output_map``.

        Parameters:
            text (str): The input text to be mapped.
            output_map (dict): A mapping of input text to output values.

        Returns:
            str: The corresponding value from ``output_map`` if ``text`` is a
                key in it, otherwise the original ``text``.
        """
        if not output_map:
            return text
        return output_map.get(text, text)

    def run_raw(self, text: str) -> List[str]:
        """Extract all pattern matches from raw text and map each one.

        Args:
            text (str): The raw text to be processed.

        Returns:
            List[str]: The mapped match strings.
        """
        matches = self.run_raw_static(self.pattern, text)
        # Use a distinct loop variable so the `text` parameter is not shadowed.
        return [self.map_output(match, self.output_map) for match in matches]

    def run_batch_raw(self, text_batch: List[str]) -> List[List[str]]:
        """Run a batch of raw text inputs through :meth:`run_raw`.

        Parameters:
            text_batch (List[str]): A list of raw text inputs to process.

        Returns:
            List[List[str]]: One list of extracted strings per input, in the
                same order as ``text_batch``.
        """
        return [self.run_raw(each_text) for each_text in text_batch]

    def run_document(self, document: Document) -> List[Document]:
        """Run a document through the regex extractor.

        Args:
            document (Document): The input document.

        Returns:
            List[Document]: One document per match, carrying the source
                document's metadata plus a ``"RegexExtractor": True`` marker.
        """
        texts = self.run_raw(document.text)
        return [
            Document(text=text, metadata={**document.metadata, "RegexExtractor": True})
            for text in texts
        ]

    def run_batch_document(
        self, document_batch: List[Document]
    ) -> List[List[Document]]:
        """Run a batch of documents through :meth:`run_document`.

        Parameters:
            document_batch (List[Document]): The batch of documents to process.

        Returns:
            List[List[Document]]: A list of lists where each inner list
                contains the output documents for the corresponding input.

        Example:
            document1 = Document(...)
            document2 = Document(...)
            batch_output = self.run_batch_document([document1, document2])
            # batch_output == [[out1_doc1, ...], [out1_doc2, ...]]
        """
        return [self.run_document(each_document) for each_document in document_batch]

    def is_document(self, text) -> bool:
        """Check whether ``text`` is an instance of :class:`Document`.

        Args:
            text: The value to check.

        Returns:
            bool: True if ``text`` is a ``Document``, False otherwise.
        """
        return isinstance(text, Document)

    def is_batch(self, text) -> bool:
        """Check whether ``text`` is a batch (a homogeneous list).

        Parameters:
            text: The value to be checked.

        Returns:
            bool: True if ``text`` is a list whose elements are uniformly
                documents or uniformly raw values (an empty list counts as a
                batch); False for non-lists or mixed lists.
        """
        if not isinstance(text, List):
            return False
        # A set of the per-element is_document() results has size <= 1 exactly
        # when the list is homogeneous (or empty).
        return len({self.is_document(each_text) for each_text in text}) <= 1
import pytest

from kotaemon.documents.base import Document
from kotaemon.post_processing.extractor import RegexExtractor


@pytest.fixture
def regex_extractor():
    """Extractor that pulls digit runs and maps 1/2/3 to English words."""
    return RegexExtractor(
        pattern=r"\d+", output_map={"1": "One", "2": "Two", "3": "Three"}
    )


def test_run_document(regex_extractor):
    """Each match in a document becomes its own mapped output document."""
    document = Document(text="This is a test. 1 2 3")
    extracted_document = regex_extractor(document)
    extracted_texts = [each.text for each in extracted_document]
    assert extracted_texts == ["One", "Two", "Three"]


def test_is_document(regex_extractor):
    """Only Document instances are recognized as documents."""
    assert regex_extractor.is_document(Document(text="Test"))
    assert not regex_extractor.is_document("Test")


def test_is_batch(regex_extractor):
    """A list of documents is a batch; a bare document is not."""
    assert regex_extractor.is_batch([Document(text="Test")])
    assert not regex_extractor.is_batch(Document(text="Test"))


def test_run_raw(regex_extractor):
    """Raw text input yields the list of (mapped) matches."""
    output = regex_extractor("This is a test. 123")
    assert output == ["123"]


def test_run_batch_raw(regex_extractor):
    """A list of raw texts yields one match-list per input."""
    output = regex_extractor(["This is a test. 123", "456"])
    assert output == [["123"], ["456"]]