kotaemon/tests/test_template.py
Nguyen Trung Duc (john) 693ed39de4 Move prompts into LLMs module (#70)
Since prompts are only used within the LLMs, it is reasonable to keep them inside the LLM module. This makes the prompt components easier to discover and keeps the code base less complicated.

Changes:

* Move prompt components into llms
* Bump version 0.3.1
* Make pip install dependencies in eager mode

---------

Co-authored-by: ian <ian@cinnamon.is>
2023-11-14 16:00:10 +07:00

114 lines
4.0 KiB
Python

import pytest
from kotaemon.llms import PromptTemplate
def test_prompt_template_creation():
    """PromptTemplate keeps the raw template text and exposes its placeholder names."""
    plain = "This is a template"
    assert PromptTemplate(plain).template == plain

    greeting = "Hello, {name}! Today is {day}."
    tpl = PromptTemplate(greeting)
    assert tpl.template == greeting
    assert tpl.placeholders == {"name", "day"}
def test_prompt_template_creation_invalid_placeholder():
    """Invalid placeholder names raise or warn depending on ``ignore_invalid``.

    ``{0day}`` is an invalid placeholder. With ``ignore_invalid=False`` the
    constructor must raise ValueError. With ``ignore_invalid=True`` it must
    emit a UserWarning naming the offending placeholder instead.
    """
    template_string = "Hello, {name}! Today is {0day}."
    with pytest.raises(ValueError):
        PromptTemplate(template_string, ignore_invalid=False)
    # ``match`` is a regular expression (checked with re.search), so the
    # trailing period is escaped to require a literal "." instead of any char.
    with pytest.warns(
        UserWarning,
        match=r"Ignore invalid placeholder: 0day\.",
    ):
        PromptTemplate(template_string, ignore_invalid=True)
def test_prompt_template_addition():
    """Adding two templates joins their text with a newline between them."""
    cases = [
        (("Hello, ", "world!"), "Hello, \nworld!"),
        (("Hello, {name}!", "Today is {day}."), "Hello, {name}!\nToday is {day}."),
    ]
    for (left, right), joined in cases:
        combined = PromptTemplate(left) + PromptTemplate(right)
        assert combined.template == joined
def test_prompt_template_extract_placeholders():
    """Placeholder names are parsed out of the template text on construction."""
    extracted = PromptTemplate("Hello, {name}! Today is {day}.").placeholders
    assert extracted == {"name", "day"}
def test_prompt_template_populate():
    """populate substitutes every placeholder with the supplied keyword values."""
    tpl = PromptTemplate("Hello, {name}! Today is {day}.")
    assert tpl.populate(name="John", day="Monday") == "Hello, John! Today is Monday."
def test_prompt_template_check_missing_kwargs():
    """Omitting a placeholder value raises ValueError in the checker and populate."""
    tpl = PromptTemplate("Hello, {name}! Today is {day}.")
    incomplete = {"name": "John"}  # "day" is deliberately left out
    for method in (tpl.check_missing_kwargs, tpl.populate):
        with pytest.raises(ValueError):
            method(**incomplete)
def test_prompt_template_check_redundant_kwargs():
    """Keyword values the template does not use trigger a UserWarning.

    The warning is expected from check_redundant_kwargs, partial_populate
    and populate alike.
    """
    tpl = PromptTemplate("Hello, {name}! Today is {day}.")
    surplus = {"name": "John", "day": "Monday", "age": "30"}  # "age" is unused
    expected_msg = "Keys provided but not in template: age"
    checked_methods = (
        tpl.check_redundant_kwargs,
        tpl.partial_populate,
        tpl.populate,
    )
    for method in checked_methods:
        with pytest.warns(UserWarning, match=expected_msg):
            method(**surplus)
def test_prompt_template_populate_complex_template():
    """populate must agree with str.format for format specs and conversions."""
    fmt = "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    values = {"a": 1, "b": "two", "c": 3, "d": 4, "e": "á"}
    assert PromptTemplate(fmt).populate(**values) == fmt.format(**values)
def test_prompt_template_partial_populate():
    """partial_populate fills the supplied keys and leaves the others untouched.

    The unfilled ``{c:.1%}`` field must survive verbatim, format spec included.
    """
    fmt = "a = {a:.2f}, b = {b}, c = {c:.1%}, d = {d:#.0g}, ascii of {e} = {e!a:>2}"
    without_c = {"a": 1, "b": "two", "d": 4, "e": "á"}
    result = PromptTemplate(fmt).partial_populate(**without_c)
    assert result == "a = 1.00, b = two, c = {c:.1%}, d = 4., ascii of á = '\\xe1'"