[Feat] Add support for f-string syntax in PromptTemplate (#38)
* Add support for f-string syntax in PromptTemplate
This commit is contained in:
parent
56bc41b673
commit
2638152054
|
@ -50,7 +50,7 @@ pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
|
|||
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
|
||||
internally uses `gpg` to encrypt and decrypt secret files.
|
||||
|
||||
This repo uses `python-dotenv` to manage credentials stored as enviroment variable.
|
||||
This repo uses `python-dotenv` to manage credentials stored as environment variable.
|
||||
Please note that the use of `python-dotenv` and credentials are for development
|
||||
purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.
|
||||
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import warnings
|
||||
from typing import Union
|
||||
from typing import Callable, Union
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.documents.base import Document
|
||||
|
@ -39,14 +38,7 @@ class BasePromptComponent(BaseComponent):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
provided_keys = set(kwargs.keys())
|
||||
expected_keys = self.template.placeholders
|
||||
|
||||
redundant_keys = provided_keys - expected_keys
|
||||
if redundant_keys:
|
||||
warnings.warn(
|
||||
f"Keys provided but not in template: {redundant_keys}", UserWarning
|
||||
)
|
||||
self.template.check_redundant_kwargs(**kwargs)
|
||||
|
||||
def __check_unset_placeholders(self):
|
||||
"""
|
||||
|
@ -62,15 +54,7 @@ class BasePromptComponent(BaseComponent):
|
|||
Returns:
|
||||
None
|
||||
"""
|
||||
expected_keys = self.template.placeholders
|
||||
|
||||
missing_keys = []
|
||||
for key in expected_keys:
|
||||
if key not in self.__dict__:
|
||||
missing_keys.append(key)
|
||||
|
||||
if missing_keys:
|
||||
raise ValueError(f"\nMissing keys in template: {missing_keys}")
|
||||
self.template.check_missing_kwargs(**self.__dict__)
|
||||
|
||||
def __validate_value_type(self, **kwargs):
|
||||
"""
|
||||
|
@ -88,14 +72,12 @@ class BasePromptComponent(BaseComponent):
|
|||
"""
|
||||
type_error = []
|
||||
for k, v in kwargs.items():
|
||||
if not isinstance(v, (str, int, Document, BaseComponent)):
|
||||
if isinstance(v, int):
|
||||
kwargs[k] = str(v)
|
||||
if not isinstance(v, (str, int, Document, Callable)): # type: ignore
|
||||
type_error.append((k, type(v)))
|
||||
|
||||
if type_error:
|
||||
raise ValueError(
|
||||
"Type of values must be either int, str, Document, BaseComponent, "
|
||||
"Type of values must be either int, str, Document, Callable, "
|
||||
f"found unsupported type for (key, type): {type_error}"
|
||||
)
|
||||
|
||||
|
@ -138,15 +120,18 @@ class BasePromptComponent(BaseComponent):
|
|||
kwargs = {}
|
||||
for k in self.template.placeholders:
|
||||
v = getattr(self, k)
|
||||
if isinstance(v, BaseComponent):
|
||||
|
||||
# if get a callable, execute to get its output
|
||||
if isinstance(v, Callable): # type: ignore[arg-type]
|
||||
v = v()
|
||||
|
||||
if isinstance(v, list):
|
||||
v = str([__prepare(k, each) for each in v])
|
||||
elif isinstance(v, (str, int, Document)):
|
||||
v = __prepare(k, v)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported type {type(v)} for template value of key {k}"
|
||||
f"Unsupported type {type(v)} for template value of key `{k}`"
|
||||
)
|
||||
kwargs[k] = v
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import re
|
||||
from typing import Set
|
||||
import warnings
|
||||
from string import Formatter
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
|
@ -7,33 +7,74 @@ class PromptTemplate:
|
|||
Base class for prompt templates.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def extract_placeholders(template: str) -> Set[str]:
|
||||
"""
|
||||
Extracts placeholders from a template string.
|
||||
|
||||
Args:
|
||||
template (str): The template string to extract placeholders from.
|
||||
|
||||
Returns:
|
||||
set[str]: A set of placeholder names found in the template string.
|
||||
"""
|
||||
placeholder_regex = r"{([a-zA-Z_][a-zA-Z0-9_]*)}"
|
||||
def __init__(self, template: str, ignore_invalid=True):
|
||||
template = template
|
||||
formatter = Formatter()
|
||||
parsed_template = list(formatter.parse(template))
|
||||
|
||||
placeholders = set()
|
||||
for item in re.findall(placeholder_regex, template):
|
||||
if item.isidentifier():
|
||||
placeholders.add(item)
|
||||
for _, key, _, _ in parsed_template:
|
||||
if key is None:
|
||||
continue
|
||||
if not key.isidentifier():
|
||||
if ignore_invalid:
|
||||
warnings.warn(f"Ignore invalid placeholder: {key}.", UserWarning)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Placeholder name must be a valid Python identifier, found:"
|
||||
f" {key}."
|
||||
)
|
||||
placeholders.add(key)
|
||||
|
||||
return placeholders
|
||||
|
||||
def __init__(self, template: str):
|
||||
self.placeholders = self.extract_placeholders(template)
|
||||
self.template = template
|
||||
self.placeholders = placeholders
|
||||
self.__formatter = formatter
|
||||
self.__parsed_template = parsed_template
|
||||
|
||||
def check_missing_kwargs(self, **kwargs):
    """
    Check that every placeholder in the template has a value in `kwargs`.

    Only the keys of `kwargs` are inspected; the values are not used.

    Args:
        **kwargs: Candidate values keyed by placeholder name.

    Returns:
        None

    Raises:
        ValueError: If one or more placeholders have no matching key. The
            message lists the missing names in sorted order.
    """
    missing_keys = self.placeholders.difference(kwargs)
    if missing_keys:
        # Sets are unordered; sort so the error message is deterministic.
        raise ValueError(
            f"Missing keys in template: {','.join(sorted(missing_keys))}"
        )
|
||||
|
||||
def check_redundant_kwargs(self, **kwargs):
    """
    Warn about keyword arguments that match no placeholder in the template.

    Only the keys of `kwargs` are inspected. Redundant keys are reported
    with a `UserWarning`; no exception is raised.

    Args:
        **kwargs: Candidate values keyed by placeholder name.

    Returns:
        None
    """
    redundant_keys = set(kwargs) - self.placeholders

    if redundant_keys:
        # Sets are unordered; sort so the warning text is deterministic.
        warnings.warn(
            f"Keys provided but not in template: {','.join(sorted(redundant_keys))}",
            UserWarning,
        )
|
||||
|
||||
def populate(self, **kwargs):
|
||||
"""
|
||||
Populate the template with the given keyword arguments.
|
||||
Strictly populate the template with the given keyword arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: The keyword arguments to populate the template.
|
||||
|
@ -44,15 +85,46 @@ class PromptTemplate:
|
|||
|
||||
Raises:
|
||||
ValueError: If an unknown placeholder is provided.
|
||||
|
||||
"""
|
||||
prompt = self.template
|
||||
for placeholder, value in kwargs.items():
|
||||
if placeholder not in self.placeholders:
|
||||
raise ValueError(f"Unknown placeholder: {placeholder}")
|
||||
prompt = prompt.replace(f"{{{placeholder}}}", value)
|
||||
self.check_missing_kwargs(**kwargs)
|
||||
|
||||
return prompt
|
||||
return self.partial_populate(**kwargs)
|
||||
|
||||
def partial_populate(self, **kwargs):
    """
    Partially populate the template with the given keyword arguments.

    Placeholders that have a value in `kwargs` are converted and formatted
    through the template's formatter. Placeholders without a value are
    written back as a replacement field (`{name}`, `{name:spec}`,
    `{name!conv}` or `{name!conv:spec}`) so they stay visible in the
    output. Keys that match no placeholder trigger the warning issued by
    `check_redundant_kwargs`.

    Args:
        **kwargs: The keyword arguments to populate the template.
            Each keyword corresponds to a placeholder in the template.

    Returns:
        str: The populated template.
    """
    self.check_redundant_kwargs(**kwargs)

    prompt = []
    for literal_text, field_name, format_spec, conversion in self.__parsed_template:
        prompt.append(literal_text)

        # A segment with no field is trailing literal text only.
        if field_name is None:
            continue

        if field_name not in kwargs:
            # Rebuild the field with the conversion and format spec INSIDE
            # the braces, and omit the `!` / `:` parts when they are empty.
            value = "{" + field_name
            if conversion:
                value += f"!{conversion}"
            if format_spec:
                value += f":{format_spec}"
            value += "}"
        else:
            value = kwargs[field_name]
            if conversion is not None:
                value = self.__formatter.convert_field(value, conversion)
            if format_spec is not None:
                value = self.__formatter.format_field(value, format_spec)

        prompt.append(value)

    return "".join(prompt)
|
||||
|
||||
def __add__(self, other):
|
||||
"""
|
||||
|
|
|
@ -24,7 +24,7 @@ def test_set_attributes():
|
|||
def test_check_redundant_kwargs():
|
||||
template = PromptTemplate("Hello, {name}!")
|
||||
prompt = BasePromptComponent(template, name="Alice")
|
||||
with pytest.warns(UserWarning, match="Keys provided but not in template: {'age'}"):
|
||||
with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
|
||||
prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
|
||||
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from kotaemon.prompt.template import PromptTemplate
|
|||
|
||||
|
||||
def test_prompt_template_creation():
|
||||
# Test case 1: Ensure the PromptTemplate object is created correctly
|
||||
# Ensure the PromptTemplate object is created correctly
|
||||
template_string = "This is a template"
|
||||
template = PromptTemplate(template_string)
|
||||
assert template.template == template_string
|
||||
|
@ -15,8 +15,22 @@ def test_prompt_template_creation():
|
|||
assert template.placeholders == {"name", "day"}
|
||||
|
||||
|
||||
def test_prompt_template_creation_invalid_placeholder():
|
||||
# Ensure the PromptTemplate object handle invalid placeholder correctly
|
||||
template_string = "Hello, {name}! Today is {0day}."
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
PromptTemplate(template_string, ignore_invalid=False)
|
||||
|
||||
with pytest.warns(
|
||||
UserWarning,
|
||||
match="Ignore invalid placeholder: 0day.",
|
||||
):
|
||||
PromptTemplate(template_string, ignore_invalid=True)
|
||||
|
||||
|
||||
def test_prompt_template_addition():
|
||||
# Test case 2: Ensure the __add__ method concatenates the templates correctly
|
||||
# Ensure the __add__ method concatenates the templates correctly
|
||||
template1 = PromptTemplate("Hello, ")
|
||||
template2 = PromptTemplate("world!")
|
||||
result = template1 + template2
|
||||
|
@ -29,25 +43,71 @@ def test_prompt_template_addition():
|
|||
|
||||
|
||||
def test_prompt_template_extract_placeholders():
|
||||
# Test case 3: Ensure the extract_placeholders method extracts placeholders
|
||||
# correctly
|
||||
# Ensure the PromptTemplate correctly extracts placeholders
|
||||
template_string = "Hello, {name}! Today is {day}."
|
||||
result = PromptTemplate.extract_placeholders(template_string)
|
||||
result = PromptTemplate(template_string).placeholders
|
||||
assert result == {"name", "day"}
|
||||
|
||||
|
||||
def test_prompt_template_populate():
|
||||
# Test case 4: Ensure the populate method populates the template correctly
|
||||
# Ensure the populate method populates the template correctly
|
||||
template_string = "Hello, {name}! Today is {day}."
|
||||
template = PromptTemplate(template_string)
|
||||
result = template.populate(name="John", day="Monday")
|
||||
assert result == "Hello, John! Today is Monday."
|
||||
|
||||
|
||||
def test_prompt_template_unknown_placeholder():
|
||||
# Test case 5: Ensure the populate method raises an exception for unknown
|
||||
# placeholders
|
||||
def test_prompt_template_check_missing_kwargs():
|
||||
# Ensure the check_missing_kwargs and populate methods raise an exception for
|
||||
# missing placeholders
|
||||
template_string = "Hello, {name}! Today is {day}."
|
||||
template = PromptTemplate(template_string)
|
||||
kwargs = dict(name="John")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
template.populate(name="John", month="January")
|
||||
template.check_missing_kwargs(**kwargs)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
template.populate(**kwargs)
|
||||
|
||||
|
||||
def test_prompt_template_check_redundant_kwargs():
|
||||
# Ensure the check_redundant_kwargs, partial_populate and populate methods warn for
|
||||
# redundant placeholders
|
||||
template_string = "Hello, {name}! Today is {day}."
|
||||
template = PromptTemplate(template_string)
|
||||
kwargs = dict(name="John", day="Monday", age="30")
|
||||
|
||||
with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
|
||||
template.check_redundant_kwargs(**kwargs)
|
||||
|
||||
with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
|
||||
template.partial_populate(**kwargs)
|
||||
|
||||
with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
|
||||
template.populate(**kwargs)
|
||||
|
||||
|
||||
def test_prompt_template_populate_complex_template():
|
||||
# Ensure the populate method produces the same results as the built-in str.format
|
||||
# function
|
||||
template_string = (
|
||||
"a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
|
||||
)
|
||||
template = PromptTemplate(template_string)
|
||||
kwargs = dict(a=1, b="two", c=3, d=4, e="á")
|
||||
populated = template.populate(**kwargs)
|
||||
expected = template_string.format(**kwargs)
|
||||
assert populated == expected
|
||||
|
||||
|
||||
def test_prompt_template_partial_populate():
|
||||
# Ensure the partial_populate method populates correctly
|
||||
template_string = (
|
||||
"a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
|
||||
)
|
||||
template = PromptTemplate(template_string)
|
||||
kwargs = dict(a=1, b="two", d=4, e="á")
|
||||
populated = template.partial_populate(**kwargs)
|
||||
expected = "a = 1.00, b = two, c = {c:.1%}, d = 4., ascii of á = '\\xe1'"
|
||||
assert populated == expected
|
||||
|
|
Loading…
Reference in New Issue
Block a user