Update Base interface of Index/Retrieval pipeline (#36)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

* add base Agent Interface, add ReWoo Agent

* minor update

* update test

* fix typo

* remove unneeded print

* update rewoo agent

* add LLMTool

* update BaseAgent type

* add ReAct agent

* add ReAct agent

* minor update

* minor update

* minor update

* minor update

* update base reader with BaseComponent

* add splitter

* update agent and tool

* update vectorstores

* update load/save for indexing and retrieving pipeline

* update test_agent for more use-cases

* add missing dependency for test

* update test case for in memory vectorstore

* add TextSplitter to BaseComponent

* update type hint basetool

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin)
2023-10-04 14:27:44 +07:00
committed by GitHub
parent 49ed3f6994
commit 56bc41b673
13 changed files with 302 additions and 36 deletions

View File

@@ -1,11 +1,19 @@
from unittest.mock import patch
import pytest
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents.react import ReactAgent
from kotaemon.pipelines.agents.rewoo import RewooAgent
from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool
from kotaemon.pipelines.tools import (
BaseTool,
GoogleSearchTool,
LLMTool,
WikipediaTool,
)
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
_openai_chat_completion_responses_rewoo = [
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
@@ -73,19 +81,61 @@ _openai_chat_completion_responses_react = [
]
]
_openai_chat_completion_responses_react_langchain_tool = [
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
"so I should gather information about it.\n"
"Action: Wikipedia\n"
"Action Input: Cinnamon AI company\n"
),
(
"The information retrieved from Wikipedia is not "
"about Cinnamon AI company, but about Blue Prism, "
"a British multinational software corporation. "
"I need to try another source to gather information "
"about Cinnamon AI company.\n"
"Action: duckduckgo_search\n"
"Action Input: Cinnamon AI company\n"
),
FINAL_RESPONSE_TEXT,
]
]
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_rewoo,
)
def test_rewoo_agent(openai_completion):
llm = AzureChatOpenAI(
@pytest.fixture
def llm():
return AzureChatOpenAI(
openai_api_base="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2",
temperature=0,
)
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_rewoo,
)
def test_rewoo_agent(openai_completion, llm):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
@@ -103,14 +153,7 @@ def test_rewoo_agent(openai_completion):
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent(openai_completion):
llm = AzureChatOpenAI(
openai_api_base="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2",
temperature=0,
)
def test_react_agent(openai_completion, llm):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
@@ -121,3 +164,47 @@ def test_react_agent(openai_completion):
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent_langchain(openai_completion, llm):
from langchain.agents import AgentType, initialize_agent
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
langchain_plugins = [tool.to_langchain_format() for tool in plugins]
agent = initialize_agent(
langchain_plugins,
llm.agent,
agent=AgentType.OPENAI_FUNCTIONS,
verbose=True,
)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react_langchain_tool,
)
def test_react_agent_with_langchain_tools(openai_completion, llm):
from langchain.tools import DuckDuckGoSearchRun, WikipediaQueryRun
from langchain.utilities import WikipediaAPIWrapper
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
search = DuckDuckGoSearchRun()
langchain_plugins = [wikipedia, search]
plugins = [BaseTool.from_langchain_format(tool) for tool in langchain_plugins]
agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT

View File

@@ -116,8 +116,8 @@ class TestInMemoryVectorStore:
"3" not in data["text_id_to_ref_doc_id"]
), "delete function does not delete data completely"
db2 = InMemoryVectorStore()
output = db2.load(load_path=tmp_path / "test_save_load_delete.json")
assert output.get("2") == [
db2.load(load_path=tmp_path / "test_save_load_delete.json")
assert db2.get("2") == [
0.4,
0.5,
0.6,