[Feat] Add support for f-string syntax in PromptTemplate (#38)

* Add support for f-string syntax in PromptTemplate
This commit is contained in:
ian_Cin 2023-10-04 16:40:09 +07:00 committed by GitHub
parent 56bc41b673
commit 2638152054
5 changed files with 183 additions and 66 deletions

View File

@ -50,7 +50,7 @@ pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
internally uses `gpg` to encrypt and decrypt secret files. internally uses `gpg` to encrypt and decrypt secret files.
This repo uses `python-dotenv` to manage credentials stored as enviroment variable. This repo uses `python-dotenv` to manage credentials stored as environment variable.
Please note that the use of `python-dotenv` and credentials are for development Please note that the use of `python-dotenv` and credentials are for development
purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`. purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.

View File

@ -1,5 +1,4 @@
import warnings from typing import Callable, Union
from typing import Union
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document from kotaemon.documents.base import Document
@ -39,14 +38,7 @@ class BasePromptComponent(BaseComponent):
Returns: Returns:
None None
""" """
provided_keys = set(kwargs.keys()) self.template.check_redundant_kwargs(**kwargs)
expected_keys = self.template.placeholders
redundant_keys = provided_keys - expected_keys
if redundant_keys:
warnings.warn(
f"Keys provided but not in template: {redundant_keys}", UserWarning
)
def __check_unset_placeholders(self): def __check_unset_placeholders(self):
""" """
@ -62,15 +54,7 @@ class BasePromptComponent(BaseComponent):
Returns: Returns:
None None
""" """
expected_keys = self.template.placeholders self.template.check_missing_kwargs(**self.__dict__)
missing_keys = []
for key in expected_keys:
if key not in self.__dict__:
missing_keys.append(key)
if missing_keys:
raise ValueError(f"\nMissing keys in template: {missing_keys}")
def __validate_value_type(self, **kwargs): def __validate_value_type(self, **kwargs):
""" """
@ -88,14 +72,12 @@ class BasePromptComponent(BaseComponent):
""" """
type_error = [] type_error = []
for k, v in kwargs.items(): for k, v in kwargs.items():
if not isinstance(v, (str, int, Document, BaseComponent)): if not isinstance(v, (str, int, Document, Callable)): # type: ignore
if isinstance(v, int):
kwargs[k] = str(v)
type_error.append((k, type(v))) type_error.append((k, type(v)))
if type_error: if type_error:
raise ValueError( raise ValueError(
"Type of values must be either int, str, Document, BaseComponent, " "Type of values must be either int, str, Document, Callable, "
f"found unsupported type for (key, type): {type_error}" f"found unsupported type for (key, type): {type_error}"
) )
@ -138,15 +120,18 @@ class BasePromptComponent(BaseComponent):
kwargs = {} kwargs = {}
for k in self.template.placeholders: for k in self.template.placeholders:
v = getattr(self, k) v = getattr(self, k)
if isinstance(v, BaseComponent):
# if get a callable, execute to get its output
if isinstance(v, Callable): # type: ignore[arg-type]
v = v() v = v()
if isinstance(v, list): if isinstance(v, list):
v = str([__prepare(k, each) for each in v]) v = str([__prepare(k, each) for each in v])
elif isinstance(v, (str, int, Document)): elif isinstance(v, (str, int, Document)):
v = __prepare(k, v) v = __prepare(k, v)
else: else:
raise ValueError( raise ValueError(
f"Unsupported type {type(v)} for template value of key {k}" f"Unsupported type {type(v)} for template value of key `{k}`"
) )
kwargs[k] = v kwargs[k] = v

View File

@ -1,5 +1,5 @@
import re import warnings
from typing import Set from string import Formatter
class PromptTemplate: class PromptTemplate:
@ -7,33 +7,74 @@ class PromptTemplate:
Base class for prompt templates. Base class for prompt templates.
""" """
@staticmethod def __init__(self, template: str, ignore_invalid=True):
def extract_placeholders(template: str) -> Set[str]: template = template
""" formatter = Formatter()
Extracts placeholders from a template string. parsed_template = list(formatter.parse(template))
Args:
template (str): The template string to extract placeholders from.
Returns:
set[str]: A set of placeholder names found in the template string.
"""
placeholder_regex = r"{([a-zA-Z_][a-zA-Z0-9_]*)}"
placeholders = set() placeholders = set()
for item in re.findall(placeholder_regex, template): for _, key, _, _ in parsed_template:
if item.isidentifier(): if key is None:
placeholders.add(item) continue
if not key.isidentifier():
if ignore_invalid:
warnings.warn(f"Ignore invalid placeholder: {key}.", UserWarning)
else:
raise ValueError(
"Placeholder name must be a valid Python identifier, found:"
f" {key}."
)
placeholders.add(key)
return placeholders
def __init__(self, template: str):
self.placeholders = self.extract_placeholders(template)
self.template = template self.template = template
self.placeholders = placeholders
self.__formatter = formatter
self.__parsed_template = parsed_template
def check_missing_kwargs(self, **kwargs):
"""
Check if all the placeholders in the template are set.
This function checks if all the expected placeholders in the template are set as
attributes of the object. If any placeholders are missing, a `ValueError`
is raised with the names of the missing keys.
Parameters:
None
Returns:
None
"""
missing_keys = self.placeholders.difference(kwargs.keys())
if missing_keys:
raise ValueError(f"Missing keys in template: {','.join(missing_keys)}")
def check_redundant_kwargs(self, **kwargs):
"""
Check if all the placeholders in the template are set.
This function checks if all the expected placeholders in the template are set as
attributes of the object. If any placeholders are missing, a `ValueError`
is raised with the names of the missing keys.
Parameters:
None
Returns:
None
"""
provided_keys = set(kwargs.keys())
redundant_keys = provided_keys - self.placeholders
if redundant_keys:
warnings.warn(
f"Keys provided but not in template: {','.join(redundant_keys)}",
UserWarning,
)
def populate(self, **kwargs): def populate(self, **kwargs):
""" """
Populate the template with the given keyword arguments. Strictly populate the template with the given keyword arguments.
Args: Args:
**kwargs: The keyword arguments to populate the template. **kwargs: The keyword arguments to populate the template.
@ -44,15 +85,46 @@ class PromptTemplate:
Raises: Raises:
ValueError: If an unknown placeholder is provided. ValueError: If an unknown placeholder is provided.
""" """
prompt = self.template self.check_missing_kwargs(**kwargs)
for placeholder, value in kwargs.items():
if placeholder not in self.placeholders:
raise ValueError(f"Unknown placeholder: {placeholder}")
prompt = prompt.replace(f"{{{placeholder}}}", value)
return prompt return self.partial_populate(**kwargs)
def partial_populate(self, **kwargs):
"""
Partially populate the template with the given keyword arguments.
Args:
**kwargs: The keyword arguments to populate the template.
Each keyword corresponds to a placeholder in the template.
Returns:
str: The populated template.
"""
self.check_redundant_kwargs(**kwargs)
prompt = []
for literal_text, field_name, format_spec, conversion in self.__parsed_template:
prompt.append(literal_text)
if field_name is None:
continue
if field_name not in kwargs:
if conversion:
value = f"{{{field_name}}}!{conversion}:{format_spec}"
else:
value = f"{{{field_name}:{format_spec}}}"
else:
value = kwargs[field_name]
if conversion is not None:
value = self.__formatter.convert_field(value, conversion)
if format_spec is not None:
value = self.__formatter.format_field(value, format_spec)
prompt.append(value)
return "".join(prompt)
def __add__(self, other): def __add__(self, other):
""" """

View File

@ -24,7 +24,7 @@ def test_set_attributes():
def test_check_redundant_kwargs(): def test_check_redundant_kwargs():
template = PromptTemplate("Hello, {name}!") template = PromptTemplate("Hello, {name}!")
prompt = BasePromptComponent(template, name="Alice") prompt = BasePromptComponent(template, name="Alice")
with pytest.warns(UserWarning, match="Keys provided but not in template: {'age'}"): with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30) prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)

View File

@ -4,7 +4,7 @@ from kotaemon.prompt.template import PromptTemplate
def test_prompt_template_creation(): def test_prompt_template_creation():
# Test case 1: Ensure the PromptTemplate object is created correctly # Ensure the PromptTemplate object is created correctly
template_string = "This is a template" template_string = "This is a template"
template = PromptTemplate(template_string) template = PromptTemplate(template_string)
assert template.template == template_string assert template.template == template_string
@ -15,8 +15,22 @@ def test_prompt_template_creation():
assert template.placeholders == {"name", "day"} assert template.placeholders == {"name", "day"}
def test_prompt_template_creation_invalid_placeholder():
    # A placeholder that is not a valid identifier is rejected in strict mode
    # and only reported through a warning in lenient mode.
    bad_template = "Hello, {name}! Today is {0day}."

    # Strict mode: construction fails.
    with pytest.raises(ValueError):
        PromptTemplate(bad_template, ignore_invalid=False)

    # Lenient mode: construction succeeds but warns about the bad name.
    expected_warning = "Ignore invalid placeholder: 0day."
    with pytest.warns(UserWarning, match=expected_warning):
        PromptTemplate(bad_template, ignore_invalid=True)
def test_prompt_template_addition(): def test_prompt_template_addition():
# Test case 2: Ensure the __add__ method concatenates the templates correctly # Ensure the __add__ method concatenates the templates correctly
template1 = PromptTemplate("Hello, ") template1 = PromptTemplate("Hello, ")
template2 = PromptTemplate("world!") template2 = PromptTemplate("world!")
result = template1 + template2 result = template1 + template2
@ -29,25 +43,71 @@ def test_prompt_template_addition():
def test_prompt_template_extract_placeholders(): def test_prompt_template_extract_placeholders():
# Test case 3: Ensure the extract_placeholders method extracts placeholders # Ensure the PromptTemplate correctly extracts placeholders
# correctly
template_string = "Hello, {name}! Today is {day}." template_string = "Hello, {name}! Today is {day}."
result = PromptTemplate.extract_placeholders(template_string) result = PromptTemplate(template_string).placeholders
assert result == {"name", "day"} assert result == {"name", "day"}
def test_prompt_template_populate(): def test_prompt_template_populate():
# Test case 4: Ensure the populate method populates the template correctly # Ensure the populate method populates the template correctly
template_string = "Hello, {name}! Today is {day}." template_string = "Hello, {name}! Today is {day}."
template = PromptTemplate(template_string) template = PromptTemplate(template_string)
result = template.populate(name="John", day="Monday") result = template.populate(name="John", day="Monday")
assert result == "Hello, John! Today is Monday." assert result == "Hello, John! Today is Monday."
def test_prompt_template_unknown_placeholder(): def test_prompt_template_check_missing_kwargs():
# Test case 5: Ensure the populate method raises an exception for unknown # Ensure the check_missing_kwargs and populate methods raise an exception for
# placeholders # missing placeholders
template_string = "Hello, {name}! Today is {day}." template_string = "Hello, {name}! Today is {day}."
template = PromptTemplate(template_string) template = PromptTemplate(template_string)
kwargs = dict(name="John")
with pytest.raises(ValueError): with pytest.raises(ValueError):
template.populate(name="John", month="January") template.check_missing_kwargs(**kwargs)
with pytest.raises(ValueError):
template.populate(**kwargs)
def test_prompt_template_check_redundant_kwargs():
    # Each of check_redundant_kwargs, partial_populate and populate must warn
    # when given a keyword that is not a placeholder of the template.
    template = PromptTemplate("Hello, {name}! Today is {day}.")
    values = {"name": "John", "day": "Monday", "age": "30"}
    expected_warning = "Keys provided but not in template: age"

    entry_points = (
        template.check_redundant_kwargs,
        template.partial_populate,
        template.populate,
    )
    for entry_point in entry_points:
        with pytest.warns(UserWarning, match=expected_warning):
            entry_point(**values)
def test_prompt_template_populate_complex_template():
    # populate must agree with the built-in str.format on a template that
    # mixes format specs and a conversion flag.
    source = "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    values = {"a": 1, "b": "two", "c": 3, "d": 4, "e": "á"}

    reference = source.format(**values)
    actual = PromptTemplate(source).populate(**values)

    assert actual == reference
def test_prompt_template_partial_populate():
    # With `c` left out, partial_populate renders the other fields and keeps
    # the `c` placeholder in the output.
    source = "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    values = {"a": 1, "b": "two", "d": 4, "e": "á"}

    actual = PromptTemplate(source).partial_populate(**values)

    assert actual == "a = 1.00, b = two, c = {c:.1%}, d = 4., ascii of á = '\\xe1'"