kotaemon/tests/test_agent.py

245 lines
7.7 KiB
Python

from unittest.mock import patch
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.agents.base import AgentType
from kotaemon.agents.langchain import LangchainAgent
from kotaemon.agents.react import ReactAgent
from kotaemon.agents.rewoo import RewooAgent
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
from kotaemon.llms import AzureChatOpenAI
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
_openai_chat_completion_responses_rewoo = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
for text in [
(
"#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
),
FINAL_RESPONSE_TEXT,
]
]
_openai_chat_completion_responses_react = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
"so I should gather information about it.\n"
"Action: wikipedia\n"
"Action Input: Cinnamon AI company\n"
),
(
"The information retrieved from Wikipedia is not "
"about Cinnamon AI company, but about Blue Prism, "
"a British multinational software corporation. "
"I need to try another source to gather information "
"about Cinnamon AI company.\n"
"Action: google_search\n"
"Action Input: Cinnamon AI company\n"
),
FINAL_RESPONSE_TEXT,
]
]
_openai_chat_completion_responses_react_langchain_tool = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
"so I should gather information about it.\n"
"Action: Wikipedia\n"
"Action Input: Cinnamon AI company\n"
),
(
"The information retrieved from Wikipedia is not "
"about Cinnamon AI company, but about Blue Prism, "
"a British multinational software corporation. "
"I need to try another source to gather information "
"about Cinnamon AI company.\n"
"Action: duckduckgo_search\n"
"Action Input: Cinnamon AI company\n"
),
FINAL_RESPONSE_TEXT,
]
]
@pytest.fixture
def llm():
return AzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2",
temperature=0,
)
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_rewoo,
)
def test_rewoo_agent(openai_completion, llm, mock_google_search):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = RewooAgent(llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.text == FINAL_RESPONSE_TEXT
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent(openai_completion, llm, mock_google_search):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.text == FINAL_RESPONSE_TEXT
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent_langchain(openai_completion, llm, mock_google_search):
from langchain.agents import AgentType, initialize_agent
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
langchain_plugins = [tool.to_langchain_format() for tool in plugins]
agent = initialize_agent(
langchain_plugins,
llm._obj,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_wrapper_agent_langchain(openai_completion, llm, mock_google_search):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = LangchainAgent(
llm=llm,
plugins=plugins,
agent_type=AgentType.react,
)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_react_langchain_tool,
)
def test_react_agent_with_langchain_tools(openai_completion, llm):
from langchain.tools import DuckDuckGoSearchRun, WikipediaQueryRun
from langchain.utilities import WikipediaAPIWrapper
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
search = DuckDuckGoSearchRun()
langchain_plugins = [wikipedia, search]
plugins = [BaseTool.from_langchain_format(tool) for tool in langchain_plugins]
agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.text == FINAL_RESPONSE_TEXT