kotaemon/tests/test_composite.py
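"""Tests for kotaemon's composite pipelines (linear, gated, and branching).

Every test stubs ``openai.resources.chat.completions.Completions.create`` via
pytest-mock's ``mocker`` fixture, so no real Azure OpenAI endpoint is contacted.
"""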

from copy import deepcopy

import pytest
from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.composite import (
    GatedBranchingPipeline,
    GatedLinearPipeline,
    SimpleBranchingPipeline,
    SimpleLinearPipeline,
)
from kotaemon.llms import BasePromptComponent
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor

# Canned chat-completion payload shared by all tests. The assistant reply
# "This is a test 123" lets RegexExtractor(r"\d+") extract "123".
_openai_chat_completion_response = ChatCompletion.parse_obj(
    {
        "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
        "object": "chat.completion",
        "created": 1692338378,
        "model": "gpt-35-turbo",
        "system_fingerprint": None,
        "choices": [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": "This is a test 123",
                },
                "logprobs": None,
            }
        ],
        "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
    }
)


@pytest.fixture
def mock_llm():
    # Placeholder credentials only; every test patches Completions.create,
    # so no request ever reaches Azure OpenAI.
    return AzureChatOpenAI(
        openai_api_base="OPENAI_API_BASE",
        openai_api_key="OPENAI_API_KEY",
        openai_api_version="OPENAI_API_VERSION",
        deployment_name="dummy-q2-gpt35",
        temperature=0,
    )


@pytest.fixture
def mock_post_processor():
    return RegexExtractor(pattern=r"\d+")


@pytest.fixture
def mock_prompt():
    return BasePromptComponent(template="Test prompt {value}")


@pytest.fixture
def mock_simple_linear_pipeline(mock_prompt, mock_llm, mock_post_processor):
    return SimpleLinearPipeline(
        prompt=mock_prompt, llm=mock_llm, post_processor=mock_post_processor
    )


@pytest.fixture
def mock_gated_linear_pipeline_positive(mock_prompt, mock_llm, mock_post_processor):
    return GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=RegexExtractor(pattern="positive"),
    )


@pytest.fixture
def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_processor):
    return GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=RegexExtractor(pattern="negative"),
    )


def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
    openai_mocker = mocker.patch(
        "openai.resources.chat.completions.Completions.create",
        return_value=_openai_chat_completion_response,
    )
    result = mock_simple_linear_pipeline(value="abc")
    # prompt -> LLM -> RegexExtractor: the digits are pulled from the reply.
    assert result.text == "123"
    assert openai_mocker.call_count == 1


def test_gated_linear_pipeline_run_positive(
    mocker, mock_gated_linear_pipeline_positive
):
    openai_mocker = mocker.patch(
        "openai.resources.chat.completions.Completions.create",
        return_value=_openai_chat_completion_response,
    )
    result = mock_gated_linear_pipeline_positive(
        value="abc", condition_text="positive condition"
    )
    assert result.text == "123"
    assert openai_mocker.call_count == 1


def test_gated_linear_pipeline_run_negative(
    mocker, mock_gated_linear_pipeline_positive
):
    openai_mocker = mocker.patch(
        "openai.resources.chat.completions.Completions.create",
        return_value=_openai_chat_completion_response,
    )
    result = mock_gated_linear_pipeline_positive(
        value="abc", condition_text="negative condition"
    )
    # The gate's condition ("positive") does not match, so the pipeline
    # short-circuits: no LLM call is made and no content is produced.
    assert result.content is None
    assert openai_mocker.call_count == 0


def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
    response0: ChatCompletion = _openai_chat_completion_response
    response1: ChatCompletion = deepcopy(_openai_chat_completion_response)
    response1.choices[0].message.content = "a quick brown fox"
    response2: ChatCompletion = deepcopy(_openai_chat_completion_response)
    response2.choices[0].message.content = "jumps over the lazy dog 456"
    # side_effect hands out one response per branch, in order.
    openai_mocker = mocker.patch(
        "openai.resources.chat.completions.Completions.create",
        side_effect=[response0, response1, response2],
    )

    pipeline = SimpleBranchingPipeline()
    for _ in range(3):
        pipeline.add_branch(mock_simple_linear_pipeline)

    result = pipeline.run(value="abc")
    texts = [each.text for each in result]
    assert len(result) == 3
    # response1 contains no digits, hence the empty extraction in the middle.
    assert texts == ["123", "", "456"]
    assert openai_mocker.call_count == 3


def test_simple_gated_branching_pipeline_run(
    mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
):
    response0: ChatCompletion = deepcopy(_openai_chat_completion_response)
    response0.choices[0].message.content = "a quick brown fox"
    openai_mocker = mocker.patch(
        "openai.resources.chat.completions.Completions.create",
        return_value=response0,
    )

    pipeline = GatedBranchingPipeline()
    pipeline.add_branch(mock_gated_linear_pipeline_negative)
    pipeline.add_branch(mock_gated_linear_pipeline_positive)
    pipeline.add_branch(mock_gated_linear_pipeline_positive)

    result = pipeline.run(value="abc", condition_text="positive condition")
    # Only the two positive branches pass their gate and call the LLM; the
    # reply has no digits, so the extracted text is empty.
    assert result.text == ""
    assert openai_mocker.call_count == 2
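

# To run just this module (pytest-mock provides the `mocker` fixture), from the
# repository root:
#   pytest kotaemon/tests/test_composite.py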