[Feat] Add support for f-string syntax in PromptTemplate (#38)
* Add support for f-string syntax in PromptTemplate
This commit is contained in:
parent
56bc41b673
commit
2638152054
|
@ -50,7 +50,7 @@ pip install kotaemon@git+ssh://git@github.com/Cinnamon/kotaemon.git
|
||||||
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
|
This repo uses [git-secret](https://sobolevn.me/git-secret/) to share credentials, which
|
||||||
internally uses `gpg` to encrypt and decrypt secret files.
|
internally uses `gpg` to encrypt and decrypt secret files.
|
||||||
|
|
||||||
This repo uses `python-dotenv` to manage credentials stored as enviroment variable.
|
This repo uses `python-dotenv` to manage credentials stored as environment variables.
|
||||||
Please note that the use of `python-dotenv` and credentials are for development
|
Please note that the use of `python-dotenv` and credentials are for development
|
||||||
purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.
|
purposes only. Thus, it should not be used in the main source code (i.e. `kotaemon/` and `tests/`), but can be used in `examples/`.
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import warnings
|
from typing import Callable, Union
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
from kotaemon.documents.base import Document
|
from kotaemon.documents.base import Document
|
||||||
|
@ -39,14 +38,7 @@ class BasePromptComponent(BaseComponent):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
provided_keys = set(kwargs.keys())
|
self.template.check_redundant_kwargs(**kwargs)
|
||||||
expected_keys = self.template.placeholders
|
|
||||||
|
|
||||||
redundant_keys = provided_keys - expected_keys
|
|
||||||
if redundant_keys:
|
|
||||||
warnings.warn(
|
|
||||||
f"Keys provided but not in template: {redundant_keys}", UserWarning
|
|
||||||
)
|
|
||||||
|
|
||||||
def __check_unset_placeholders(self):
|
def __check_unset_placeholders(self):
|
||||||
"""
|
"""
|
||||||
|
@ -62,15 +54,7 @@ class BasePromptComponent(BaseComponent):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
expected_keys = self.template.placeholders
|
self.template.check_missing_kwargs(**self.__dict__)
|
||||||
|
|
||||||
missing_keys = []
|
|
||||||
for key in expected_keys:
|
|
||||||
if key not in self.__dict__:
|
|
||||||
missing_keys.append(key)
|
|
||||||
|
|
||||||
if missing_keys:
|
|
||||||
raise ValueError(f"\nMissing keys in template: {missing_keys}")
|
|
||||||
|
|
||||||
def __validate_value_type(self, **kwargs):
|
def __validate_value_type(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -88,14 +72,12 @@ class BasePromptComponent(BaseComponent):
|
||||||
"""
|
"""
|
||||||
type_error = []
|
type_error = []
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if not isinstance(v, (str, int, Document, BaseComponent)):
|
if not isinstance(v, (str, int, Document, Callable)): # type: ignore
|
||||||
if isinstance(v, int):
|
|
||||||
kwargs[k] = str(v)
|
|
||||||
type_error.append((k, type(v)))
|
type_error.append((k, type(v)))
|
||||||
|
|
||||||
if type_error:
|
if type_error:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Type of values must be either int, str, Document, BaseComponent, "
|
"Type of values must be either int, str, Document, Callable, "
|
||||||
f"found unsupported type for (key, type): {type_error}"
|
f"found unsupported type for (key, type): {type_error}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -138,15 +120,18 @@ class BasePromptComponent(BaseComponent):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
for k in self.template.placeholders:
|
for k in self.template.placeholders:
|
||||||
v = getattr(self, k)
|
v = getattr(self, k)
|
||||||
if isinstance(v, BaseComponent):
|
|
||||||
|
# if get a callable, execute to get its output
|
||||||
|
if isinstance(v, Callable): # type: ignore[arg-type]
|
||||||
v = v()
|
v = v()
|
||||||
|
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = str([__prepare(k, each) for each in v])
|
v = str([__prepare(k, each) for each in v])
|
||||||
elif isinstance(v, (str, int, Document)):
|
elif isinstance(v, (str, int, Document)):
|
||||||
v = __prepare(k, v)
|
v = __prepare(k, v)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported type {type(v)} for template value of key {k}"
|
f"Unsupported type {type(v)} for template value of key `{k}`"
|
||||||
)
|
)
|
||||||
kwargs[k] = v
|
kwargs[k] = v
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import re
|
import warnings
|
||||||
from typing import Set
|
from string import Formatter
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplate:
|
class PromptTemplate:
|
||||||
|
@ -7,33 +7,74 @@ class PromptTemplate:
|
||||||
Base class for prompt templates.
|
Base class for prompt templates.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
def __init__(self, template: str, ignore_invalid=True):
|
||||||
def extract_placeholders(template: str) -> Set[str]:
|
template = template
|
||||||
"""
|
formatter = Formatter()
|
||||||
Extracts placeholders from a template string.
|
parsed_template = list(formatter.parse(template))
|
||||||
|
|
||||||
Args:
|
|
||||||
template (str): The template string to extract placeholders from.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
set[str]: A set of placeholder names found in the template string.
|
|
||||||
"""
|
|
||||||
placeholder_regex = r"{([a-zA-Z_][a-zA-Z0-9_]*)}"
|
|
||||||
|
|
||||||
placeholders = set()
|
placeholders = set()
|
||||||
for item in re.findall(placeholder_regex, template):
|
for _, key, _, _ in parsed_template:
|
||||||
if item.isidentifier():
|
if key is None:
|
||||||
placeholders.add(item)
|
continue
|
||||||
|
if not key.isidentifier():
|
||||||
|
if ignore_invalid:
|
||||||
|
warnings.warn(f"Ignore invalid placeholder: {key}.", UserWarning)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Placeholder name must be a valid Python identifier, found:"
|
||||||
|
f" {key}."
|
||||||
|
)
|
||||||
|
placeholders.add(key)
|
||||||
|
|
||||||
return placeholders
|
|
||||||
|
|
||||||
def __init__(self, template: str):
|
|
||||||
self.placeholders = self.extract_placeholders(template)
|
|
||||||
self.template = template
|
self.template = template
|
||||||
|
self.placeholders = placeholders
|
||||||
|
self.__formatter = formatter
|
||||||
|
self.__parsed_template = parsed_template
|
||||||
|
|
||||||
|
def check_missing_kwargs(self, **kwargs):
    """
    Check that every placeholder in the template has a value.

    The names in ``self.placeholders`` are compared against the provided
    keyword names; any placeholder without a matching keyword is reported.

    Args:
        **kwargs: Candidate values keyed by placeholder name. Only the keys
            are inspected.

    Returns:
        None

    Raises:
        ValueError: If one or more placeholders are absent from ``kwargs``.
    """
    missing_keys = self.placeholders.difference(kwargs.keys())
    if missing_keys:
        # Sets have no stable order; sort so the message is deterministic.
        missing = ",".join(sorted(missing_keys))
        raise ValueError(f"Missing keys in template: {missing}")
|
||||||
|
|
||||||
|
def check_redundant_kwargs(self, **kwargs):
    """
    Warn about keyword arguments that match no placeholder in the template.

    The provided keyword names are compared against ``self.placeholders``.
    Extra names are not an error: a single ``UserWarning`` listing them is
    emitted and execution continues.

    Args:
        **kwargs: Candidate values keyed by placeholder name. Only the keys
            are inspected.

    Returns:
        None
    """
    redundant_keys = set(kwargs.keys()) - self.placeholders

    if redundant_keys:
        # Sets have no stable order; sort so the message is deterministic.
        redundant = ",".join(sorted(redundant_keys))
        warnings.warn(
            f"Keys provided but not in template: {redundant}",
            UserWarning,
        )
|
||||||
|
|
||||||
def populate(self, **kwargs):
|
def populate(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Populate the template with the given keyword arguments.
|
Strictly populate the template with the given keyword arguments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**kwargs: The keyword arguments to populate the template.
|
**kwargs: The keyword arguments to populate the template.
|
||||||
|
@ -44,15 +85,46 @@ class PromptTemplate:
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If an unknown placeholder is provided.
|
ValueError: If an unknown placeholder is provided.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
prompt = self.template
|
self.check_missing_kwargs(**kwargs)
|
||||||
for placeholder, value in kwargs.items():
|
|
||||||
if placeholder not in self.placeholders:
|
|
||||||
raise ValueError(f"Unknown placeholder: {placeholder}")
|
|
||||||
prompt = prompt.replace(f"{{{placeholder}}}", value)
|
|
||||||
|
|
||||||
return prompt
|
return self.partial_populate(**kwargs)
|
||||||
|
|
||||||
|
def partial_populate(self, **kwargs):
    """
    Partially populate the template with the given keyword arguments.

    Placeholders that have a value in ``kwargs`` are converted and formatted
    with the stored formatter. Placeholders without a value are written back
    in their replacement-field form (``{name}``, ``{name:spec}``,
    ``{name!c}`` or ``{name!c:spec}``) so they can be filled in later.

    Args:
        **kwargs: The keyword arguments to populate the template.
            Each keyword corresponds to a placeholder in the template.
            Keywords that match no placeholder trigger a warning via
            ``check_redundant_kwargs``.

    Returns:
        str: The populated template.
    """
    self.check_redundant_kwargs(**kwargs)

    prompt = []
    for literal_text, field_name, format_spec, conversion in self.__parsed_template:
        prompt.append(literal_text)

        # A trailing literal segment carries no replacement field.
        if field_name is None:
            continue

        if field_name not in kwargs:
            # Rebuild the replacement field with the conversion and format
            # spec INSIDE the braces, omitting the parts that were absent.
            placeholder = field_name
            if conversion:
                placeholder += f"!{conversion}"
            if format_spec:
                placeholder += f":{format_spec}"
            prompt.append(f"{{{placeholder}}}")
            continue

        value = kwargs[field_name]
        if conversion is not None:
            value = self.__formatter.convert_field(value, conversion)
        # An empty spec makes format_field behave like str(value), so the
        # joined result is always made of strings.
        prompt.append(self.__formatter.format_field(value, format_spec or ""))

    return "".join(prompt)
|
||||||
|
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -24,7 +24,7 @@ def test_set_attributes():
|
||||||
def test_check_redundant_kwargs():
|
def test_check_redundant_kwargs():
|
||||||
template = PromptTemplate("Hello, {name}!")
|
template = PromptTemplate("Hello, {name}!")
|
||||||
prompt = BasePromptComponent(template, name="Alice")
|
prompt = BasePromptComponent(template, name="Alice")
|
||||||
with pytest.warns(UserWarning, match="Keys provided but not in template: {'age'}"):
|
with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
|
||||||
prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
|
prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from kotaemon.prompt.template import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_template_creation():
|
def test_prompt_template_creation():
|
||||||
# Test case 1: Ensure the PromptTemplate object is created correctly
|
# Ensure the PromptTemplate object is created correctly
|
||||||
template_string = "This is a template"
|
template_string = "This is a template"
|
||||||
template = PromptTemplate(template_string)
|
template = PromptTemplate(template_string)
|
||||||
assert template.template == template_string
|
assert template.template == template_string
|
||||||
|
@ -15,8 +15,22 @@ def test_prompt_template_creation():
|
||||||
assert template.placeholders == {"name", "day"}
|
assert template.placeholders == {"name", "day"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_template_creation_invalid_placeholder():
    # A non-identifier placeholder raises in strict mode and only warns otherwise
    bad_template = "Hello, {name}! Today is {0day}."

    with pytest.raises(ValueError):
        PromptTemplate(bad_template, ignore_invalid=False)

    expected_warning = "Ignore invalid placeholder: 0day."
    with pytest.warns(UserWarning, match=expected_warning):
        PromptTemplate(bad_template, ignore_invalid=True)
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_template_addition():
|
def test_prompt_template_addition():
|
||||||
# Test case 2: Ensure the __add__ method concatenates the templates correctly
|
# Ensure the __add__ method concatenates the templates correctly
|
||||||
template1 = PromptTemplate("Hello, ")
|
template1 = PromptTemplate("Hello, ")
|
||||||
template2 = PromptTemplate("world!")
|
template2 = PromptTemplate("world!")
|
||||||
result = template1 + template2
|
result = template1 + template2
|
||||||
|
@ -29,25 +43,71 @@ def test_prompt_template_addition():
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_template_extract_placeholders():
|
def test_prompt_template_extract_placeholders():
|
||||||
# Test case 3: Ensure the extract_placeholders method extracts placeholders
|
# Ensure the PromptTemplate correctly extracts placeholders
|
||||||
# correctly
|
|
||||||
template_string = "Hello, {name}! Today is {day}."
|
template_string = "Hello, {name}! Today is {day}."
|
||||||
result = PromptTemplate.extract_placeholders(template_string)
|
result = PromptTemplate(template_string).placeholders
|
||||||
assert result == {"name", "day"}
|
assert result == {"name", "day"}
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_template_populate():
|
def test_prompt_template_populate():
|
||||||
# Test case 4: Ensure the populate method populates the template correctly
|
# Ensure the populate method populates the template correctly
|
||||||
template_string = "Hello, {name}! Today is {day}."
|
template_string = "Hello, {name}! Today is {day}."
|
||||||
template = PromptTemplate(template_string)
|
template = PromptTemplate(template_string)
|
||||||
result = template.populate(name="John", day="Monday")
|
result = template.populate(name="John", day="Monday")
|
||||||
assert result == "Hello, John! Today is Monday."
|
assert result == "Hello, John! Today is Monday."
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_template_unknown_placeholder():
|
def test_prompt_template_check_missing_kwargs():
|
||||||
# Test case 5: Ensure the populate method raises an exception for unknown
|
# Ensure the check_missing_kwargs and populate methods raise an exception for
|
||||||
# placeholders
|
# missing placeholders
|
||||||
template_string = "Hello, {name}! Today is {day}."
|
template_string = "Hello, {name}! Today is {day}."
|
||||||
template = PromptTemplate(template_string)
|
template = PromptTemplate(template_string)
|
||||||
|
kwargs = dict(name="John")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
template.populate(name="John", month="January")
|
template.check_missing_kwargs(**kwargs)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
template.populate(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_template_check_redundant_kwargs():
    # check_redundant_kwargs, partial_populate and populate must each warn
    # when given a keyword that matches no placeholder
    template = PromptTemplate("Hello, {name}! Today is {day}.")
    values = {"name": "John", "day": "Monday", "age": "30"}
    expected_warning = "Keys provided but not in template: age"

    methods = (
        template.check_redundant_kwargs,
        template.partial_populate,
        template.populate,
    )
    for method in methods:
        with pytest.warns(UserWarning, match=expected_warning):
            method(**values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_template_populate_complex_template():
    # populate must agree with the built-in str.format on format specs
    # and conversions
    template_string = (
        "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    )
    values = {"a": 1, "b": "two", "c": 3, "d": 4, "e": "á"}

    populated = PromptTemplate(template_string).populate(**values)

    assert populated == template_string.format(**values)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_template_partial_populate():
    # partial_populate fills the provided placeholders and leaves the
    # missing one (`c`) in place
    template_string = (
        "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    )
    values = {"a": 1, "b": "two", "d": 4, "e": "á"}

    populated = PromptTemplate(template_string).partial_populate(**values)

    assert populated == "a = 1.00, b = two, c = {c:.1%}, d = 4., ascii of á = '\\xe1'"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user