refractor agents (#100)

* refractor agents

* minor cosmetic, add terminal ui for cli

* pump to 0.3.4

* Add temporary path

* fix unclose files in tests

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2023-12-06 17:06:29 +07:00
committed by GitHub
parent d9e925eb75
commit 797df5a69c
21 changed files with 281 additions and 228 deletions

View File

@@ -3,73 +3,69 @@ from unittest.mock import patch
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.agents.base import AgentType
from kotaemon.agents.langchain import LangchainAgent
from kotaemon.agents.react import ReactAgent
from kotaemon.agents.rewoo import RewooAgent
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
from kotaemon.agents import (
AgentType,
BaseTool,
GoogleSearchTool,
LangchainAgent,
LLMTool,
ReactAgent,
RewooAgent,
WikipediaTool,
)
from kotaemon.llms import AzureChatOpenAI
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
REWOO_VALID_PLAN = (
"#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
)
REWOO_INVALID_PLAN = (
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
)
def generate_chat_completion_obj(text):
return ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
_openai_chat_completion_responses_rewoo = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
for text in [
(
"#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
),
FINAL_RESPONSE_TEXT,
]
generate_chat_completion_obj(text=text)
for text in [REWOO_VALID_PLAN, FINAL_RESPONSE_TEXT]
]
_openai_chat_completion_responses_rewoo_error = [
generate_chat_completion_obj(text=text)
for text in [REWOO_INVALID_PLAN, FINAL_RESPONSE_TEXT]
]
_openai_chat_completion_responses_react = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
generate_chat_completion_obj(text=text)
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
@@ -91,28 +87,7 @@ _openai_chat_completion_responses_react = [
]
_openai_chat_completion_responses_react_langchain_tool = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
generate_chat_completion_obj(text=text)
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
@@ -145,6 +120,25 @@ def llm():
)
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_rewoo_error,
)
def test_agent_fail(openai_completion, llm, mock_google_search):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert not response
assert response.status == "failed"
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_rewoo,
@@ -156,7 +150,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
LLMTool(llm=llm),
]
agent = RewooAgent(llm=llm, plugins=plugins)
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()