114 lines
4.1 KiB
Python
114 lines
4.1 KiB
Python
import pytest
|
|
|
|
from kotaemon.prompt.template import PromptTemplate
|
|
|
|
|
|
def test_prompt_template_creation():
    """PromptTemplate keeps the raw template text and exposes its placeholders."""
    # A template without placeholders is stored verbatim.
    plain = "This is a template"
    assert PromptTemplate(plain).template == plain

    # A template with placeholders is stored verbatim, and the field names
    # are collected into ``placeholders``.
    templated = "Hello, {name}! Today is {day}."
    prompt = PromptTemplate(templated)
    assert prompt.template == templated
    assert prompt.placeholders == {"name", "day"}
|
|
|
|
|
|
def test_prompt_template_creation_invalid_placeholder():
    """PromptTemplate rejects or warns about non-identifier placeholders."""
    # "0day" is not a valid Python identifier, so it is an invalid placeholder.
    template_string = "Hello, {name}! Today is {0day}."

    # Strict mode: the invalid placeholder raises.
    with pytest.raises(ValueError):
        PromptTemplate(template_string, ignore_invalid=False)

    # Lenient mode: the invalid placeholder only emits a warning.
    # ``match`` is a regular expression applied with re.search, so the trailing
    # period is escaped to require a literal "." rather than any character.
    with pytest.warns(UserWarning, match=r"Ignore invalid placeholder: 0day\."):
        PromptTemplate(template_string, ignore_invalid=True)
|
|
|
|
|
|
def test_prompt_template_addition():
    """Adding two templates joins their text with a newline separator."""
    # Templates without placeholders.
    greeting, target = PromptTemplate("Hello, "), PromptTemplate("world!")
    assert (greeting + target).template == "Hello, \nworld!"

    # Templates with placeholders keep their fields intact in the combined text.
    combined = PromptTemplate("Hello, {name}!") + PromptTemplate("Today is {day}.")
    assert combined.template == "Hello, {name}!\nToday is {day}."
|
|
|
|
|
|
def test_prompt_template_extract_placeholders():
    """Placeholder names are parsed out of the template string."""
    placeholders = PromptTemplate("Hello, {name}! Today is {day}.").placeholders
    assert placeholders == {"name", "day"}
|
|
|
|
|
|
def test_prompt_template_populate():
    """populate() fills every placeholder and returns the rendered string."""
    prompt = PromptTemplate("Hello, {name}! Today is {day}.")
    rendered = prompt.populate(name="John", day="Monday")
    assert rendered == "Hello, John! Today is Monday."
|
|
|
|
|
|
def test_prompt_template_check_missing_kwargs():
    """A missing placeholder value makes the check and populate() raise."""
    prompt = PromptTemplate("Hello, {name}! Today is {day}.")
    # Only ``name`` is supplied; ``day`` is deliberately left out.
    incomplete = {"name": "John"}

    for call in (prompt.check_missing_kwargs, prompt.populate):
        with pytest.raises(ValueError):
            call(**incomplete)
|
|
|
|
|
|
def test_prompt_template_check_redundant_kwargs():
    """Keyword arguments without a matching placeholder trigger a warning."""
    prompt = PromptTemplate("Hello, {name}! Today is {day}.")
    # ``age`` has no matching placeholder in the template.
    extra = {"name": "John", "day": "Monday", "age": "30"}
    expected_warning = "Keys provided but not in template: age"

    for call in (
        prompt.check_redundant_kwargs,
        prompt.partial_populate,
        prompt.populate,
    ):
        with pytest.warns(UserWarning, match=expected_warning):
            call(**extra)
|
|
|
|
|
|
def test_prompt_template_populate_complex_template():
    """populate() handles format specs and conversions exactly like str.format."""
    template_string = (
        "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    )
    values = {"a": 1, "b": "two", "c": 3, "d": 4, "e": "á"}

    actual = PromptTemplate(template_string).populate(**values)
    assert actual == template_string.format(**values)
|
|
|
|
|
|
def test_prompt_template_partial_populate():
    """partial_populate() fills the supplied fields and leaves the rest intact."""
    prompt = PromptTemplate(
        "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    )
    # ``c`` is omitted, so its field (format spec included) must survive verbatim.
    partial = prompt.partial_populate(a=1, b="two", d=4, e="á")
    assert partial == "a = 1.00, b = two, c = {c:.1%}, d = 4., ascii of á = '\\xe1'"
|