diff --git a/knowledgehub/documents/base.py b/knowledgehub/documents/base.py
index 30f540d..579031b 100644
--- a/knowledgehub/documents/base.py
+++ b/knowledgehub/documents/base.py
@@ -22,6 +22,9 @@ class Document(BaseDocument):
         text = self.text
         return HaystackDocument(content=text, meta=metadata)
 
+    def __str__(self):
+        return self.text
+
 
 class RetrievedDocument(Document):
     """Subclass of Document with retrieval-related information
diff --git a/knowledgehub/post_processing/extractor.py b/knowledgehub/post_processing/extractor.py
index 3d56a71..61b6254 100644
--- a/knowledgehub/post_processing/extractor.py
+++ b/knowledgehub/post_processing/extractor.py
@@ -6,7 +6,8 @@ from kotaemon.documents.base import Document
 
 
 class RegexExtractor(BaseComponent):
-    """Simple class for extracting text from a document using a regex pattern.
+    """
+    Simple class for extracting text from a document using a regex pattern.
 
     Args:
         pattern (str): The regex pattern to use.
diff --git a/knowledgehub/prompt/__init__.py b/knowledgehub/prompt/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/knowledgehub/prompt/base.py b/knowledgehub/prompt/base.py
new file mode 100644
--- /dev/null
+++ b/knowledgehub/prompt/base.py
@@ -0,0 +1,182 @@
+from typing import Union
+
+from kotaemon.base import BaseComponent
+from kotaemon.documents.base import Document
+from kotaemon.prompt.template import PromptTemplate
+
+
+class BasePrompt(BaseComponent):
+    """
+    Base class for prompt components.
+
+    Args:
+        template (PromptTemplate): The prompt template.
+        **kwargs: Any additional keyword arguments that will be used to populate the
+            given template.
+    """
+
+    def __check_redundant_kwargs(self, **kwargs):
+        """
+        Check for redundant keyword arguments.
+
+        Parameters:
+            **kwargs (dict): A dictionary of keyword arguments.
+
+        Raises:
+            ValueError: If any keys provided are not in the template.
+
+        Returns:
+            None
+        """
+        provided_keys = set(kwargs.keys())
+        expected_keys = self.template.placeholders
+
+        redundant_keys = provided_keys - expected_keys
+        if redundant_keys:
+            raise ValueError(f"\nKeys provided but not in template: {redundant_keys}")
+
+    def __check_unset_placeholders(self):
+        """
+        Check if all the placeholders in the template are set.
+
+        This function checks if all the expected placeholders in the template are set as
+        attributes of the object. If any placeholders are missing, a `ValueError`
+        is raised with the names of the missing keys.
+
+        Parameters:
+            None
+
+        Returns:
+            None
+        """
+        expected_keys = self.template.placeholders
+
+        missing_keys = []
+        for key in expected_keys:
+            if key not in self.__dict__:
+                missing_keys.append(key)
+
+        if missing_keys:
+            raise ValueError(f"\nMissing keys in template: {missing_keys}")
+
+    def __validate_value_type(self, **kwargs):
+        """
+        Validates the value types of the given keyword arguments.
+
+        Parameters:
+            **kwargs (dict): A dictionary of keyword arguments to be validated.
+
+        Raises:
+            ValueError: If any of the values in the kwargs dictionary have an
+                unsupported type.
+
+        Returns:
+            None
+        """
+        type_error = []
+        for k, v in kwargs.items():
+            if not isinstance(v, (str, int, Document, BaseComponent)):
+                type_error.append((k, type(v)))
+
+        if type_error:
+            raise ValueError(
+                "Type of values must be either int, str, Document, BaseComponent, "
+                f"found unsupported type for (key, type): {type_error}"
+            )
+
+    def __set(self, **kwargs):
+        """
+        Set the values of the attributes in the object based on the provided keyword
+        arguments.
+
+        Args:
+            kwargs (dict): A dictionary with the attribute names as keys and the new
+            values as values.
+
+        Returns:
+            None
+        """
+        self.__check_redundant_kwargs(**kwargs)
+        self.__validate_value_type(**kwargs)
+
+        self.__dict__.update(kwargs)
+
+    def __prepare_value(self):
+        """
+        Generate a dictionary of keyword arguments based on the template's placeholders
+        and the current instance's attributes.
+
+        Returns:
+            dict: A dictionary of keyword arguments.
+        """
+        kwargs = {}
+        for k in self.template.placeholders:
+            v = getattr(self, k)
+            if isinstance(v, (int, Document)):
+                v = str(v)
+            elif isinstance(v, BaseComponent):
+                v = str(v())
+            kwargs[k] = v
+
+        return kwargs
+
+    def __init__(self, template: Union[str, PromptTemplate], **kwargs):
+        super().__init__()
+        self.template = (
+            template
+            if isinstance(template, PromptTemplate)
+            else PromptTemplate(template)
+        )
+
+        self.__set(**kwargs)
+
+    def set(self, **kwargs):
+        """
+        Similar to `__set` but for external use.
+
+        Set the values of the attributes in the object based on the provided keyword
+        arguments.
+
+        Args:
+            kwargs (dict): A dictionary with the attribute names as keys and the new
+            values as values.
+
+        Returns:
+            None
+        """
+        self.__set(**kwargs)
+
+    def run(self, **kwargs):
+        """
+        Run the function with the given keyword arguments.
+
+        Args:
+            **kwargs: The keyword arguments to pass to the function.
+
+        Returns:
+            The result of calling the `populate` method of the `template` object
+            with the given keyword arguments.
+        """
+        self.__set(**kwargs)
+        self.__check_unset_placeholders()
+        prepared_kwargs = self.__prepare_value()
+
+        return self.template.populate(**prepared_kwargs)
+
+    def run_raw(self, *args, **kwargs):
+        pass
+
+    def run_batch_raw(self, *args, **kwargs):
+        pass
+
+    def run_document(self, *args, **kwargs):
+        pass
+
+    def run_batch_document(self, *args, **kwargs):
+        pass
+
+    def is_document(self, *args, **kwargs):
+        pass
+
+    def is_batch(self, *args, **kwargs):
+        pass
diff --git a/knowledgehub/prompt/template.py b/knowledgehub/prompt/template.py
new file mode 100644
--- /dev/null
+++ b/knowledgehub/prompt/template.py
@@ -0,0 +1,72 @@
+import re
+from typing import Set
+
+# Matches a `{name}` placeholder; group 1 captures the placeholder name.
+PLACEHOLDER_REGEX = re.compile(r"{([a-zA-Z_][a-zA-Z0-9_]*)}")
+
+
+class PromptTemplate:
+    """
+    Base class for prompt templates.
+    """
+
+    @staticmethod
+    def extract_placeholders(template: str) -> Set[str]:
+        """
+        Extracts placeholders from a template string.
+
+        Args:
+            template (str): The template string to extract placeholders from.
+
+        Returns:
+            set[str]: A set of placeholder names found in the template string.
+        """
+        return set(PLACEHOLDER_REGEX.findall(template))
+
+    def __init__(self, template: str):
+        self.placeholders = self.extract_placeholders(template)
+        self.template = template
+
+    def populate(self, **kwargs):
+        """
+        Populate the template with the given keyword arguments.
+
+        Substitution happens in a single pass over the original template, so a
+        value that itself contains `{name}` text is inserted verbatim and is never
+        re-expanded by another keyword.
+
+        Args:
+            **kwargs: The keyword arguments to populate the template.
+                      Each keyword corresponds to a placeholder in the template.
+
+        Returns:
+            str: The populated template. Placeholders without a matching keyword
+                are left untouched.
+
+        Raises:
+            ValueError: If an unknown placeholder is provided.
+
+        """
+        for placeholder in kwargs:
+            if placeholder not in self.placeholders:
+                raise ValueError(f"Unknown placeholder: {placeholder}")
+
+        def substitute(match):
+            # Keep placeholders that were not supplied exactly as written.
+            name = match.group(1)
+            return kwargs[name] if name in kwargs else match.group(0)
+
+        return PLACEHOLDER_REGEX.sub(substitute, self.template)
+
+    def __add__(self, other):
+        """
+        Create a new PromptTemplate object by concatenating the template of the current
+        object with the template of another PromptTemplate object.
+
+        Parameters:
+            other (PromptTemplate): Another PromptTemplate object.
+
+        Returns:
+            PromptTemplate: A new PromptTemplate object with the concatenated templates.
+        """
+        return PromptTemplate(self.template + "\n" + other.template)
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
new file mode 100644
index 0000000..b55d42b
--- /dev/null
+++ b/tests/test_prompt.py
@@ -0,0 +1,67 @@
+import pytest
+
+from kotaemon.documents.base import Document
+from kotaemon.post_processing.extractor import RegexExtractor
+from kotaemon.prompt.base import BasePrompt
+from kotaemon.prompt.template import PromptTemplate
+
+
+def test_set_attributes():
+    template = PromptTemplate("str = {s}, int = {i}, doc = {doc}, comp = {comp}")
+    doc = Document(text="Helloo, Alice!")
+    comp = RegexExtractor(
+        pattern=r"\d+", output_map={"1": "One", "2": "Two", "3": "Three"}
+    )
+    comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
+
+    prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp)
+    assert prompt.s == "Alice"
+    assert prompt.i == 30
+    assert prompt.doc == doc
+    assert prompt.comp == comp
+
+
+def test_check_redundant_kwargs():
+    template = PromptTemplate("Hello, {name}!")
+    prompt = BasePrompt(template, name="Alice")
+    with pytest.raises(ValueError):
+        prompt._BasePrompt__check_redundant_kwargs(name="Alice", age=30)
+
+
+def test_check_unset_placeholders():
+    template = PromptTemplate("Hello, {name}! I'm {age} years old.")
+    prompt = BasePrompt(template, name="Alice")
+    with pytest.raises(ValueError):
+        prompt._BasePrompt__check_unset_placeholders()
+
+
+def test_validate_value_type():
+    template = PromptTemplate("Hello, {name}!")
+    prompt = BasePrompt(template)
+    with pytest.raises(ValueError):
+        prompt._BasePrompt__validate_value_type(name={})
+
+
+def test_run():
+    template = PromptTemplate("str = {s}, int = {i}, doc = {doc}, comp = {comp}")
+    doc = Document(text="Helloo, Alice!")
+    comp = RegexExtractor(
+        pattern=r"\d+", output_map={"1": "One", "2": "Two", "3": "Three"}
+    )
+    comp.set_run(kwargs={"text": "This is a test. 1 2 3"}, temp=True)
+
+    prompt = BasePrompt(template=template, s="Alice", i=30, doc=doc, comp=comp)
+
+    result = prompt()
+
+    assert (
+        result
+        == "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
+    )
+
+
+def test_set_method():
+    template = PromptTemplate("Hello, {name}!")
+    prompt = BasePrompt(template)
+    prompt.set(name="Alice")
+    assert prompt.name == "Alice"
diff --git a/tests/test_template.py b/tests/test_template.py
new file mode 100644
--- /dev/null
+++ b/tests/test_template.py
@@ -0,0 +1,61 @@
+import pytest
+
+from kotaemon.prompt.template import PromptTemplate
+
+
+def test_prompt_template_creation():
+    # Test case 1: Ensure the PromptTemplate object is created correctly
+    template_string = "This is a template"
+    template = PromptTemplate(template_string)
+    assert template.template == template_string
+
+    template_string = "Hello, {name}! Today is {day}."
+    template = PromptTemplate(template_string)
+    assert template.template == template_string
+    assert template.placeholders == {"name", "day"}
+
+
+def test_prompt_template_addition():
+    # Test case 2: Ensure the __add__ method concatenates the templates correctly
+    template1 = PromptTemplate("Hello, ")
+    template2 = PromptTemplate("world!")
+    result = template1 + template2
+    assert result.template == "Hello, \nworld!"
+
+    template1 = PromptTemplate("Hello, {name}!")
+    template2 = PromptTemplate("Today is {day}.")
+    result = template1 + template2
+    assert result.template == "Hello, {name}!\nToday is {day}."
+
+
+def test_prompt_template_extract_placeholders():
+    # Test case 3: Ensure the extract_placeholders method extracts placeholders
+    # correctly
+    template_string = "Hello, {name}! Today is {day}."
+    result = PromptTemplate.extract_placeholders(template_string)
+    assert result == {"name", "day"}
+
+
+def test_prompt_template_populate():
+    # Test case 4: Ensure the populate method populates the template correctly
+    template_string = "Hello, {name}! Today is {day}."
+    template = PromptTemplate(template_string)
+    result = template.populate(name="John", day="Monday")
+    assert result == "Hello, John! Today is Monday."
+
+
+def test_prompt_template_unknown_placeholder():
+    # Test case 5: Ensure the populate method raises an exception for unknown
+    # placeholders
+    template_string = "Hello, {name}! Today is {day}."
+    template = PromptTemplate(template_string)
+    with pytest.raises(ValueError):
+        template.populate(name="John", month="January")
+
+
+def test_prompt_template_populate_single_pass():
+    # Test case 6: Ensure a value containing placeholder-like text is inserted
+    # verbatim rather than being re-expanded by another keyword
+    template = PromptTemplate("Hello, {name}! Today is {day}.")
+    result = template.populate(name="{day}", day="Monday")
+    assert result == "Hello, {day}! Today is Monday."
diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py
index 8b3556c..98d34d0 100644
--- a/tests/test_vectorstore.py
+++ b/tests/test_vectorstore.py
@@ -56,7 +56,7 @@ class TestChromaVectorStore:
         db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
 
         _, sim, out_ids = db.query(embedding=[0.1, 0.2, 0.3], top_k=1)
-        assert sim == [0.0]
+        assert sim == [1.0]
         assert out_ids == ["a"]
 
         _, _, out_ids = db.query(embedding=[0.42, 0.52, 0.53], top_k=1)