[AUR-431, AUR-435] Add Agent Interface and ReWOO Agent implementation (#31)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

* add base Agent Interface, add ReWoo Agent

* minor update

* update test

* fix typo

* remove unneeded print

* update rewoo agent

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
Authored by Tuan Anh Nguyen Dang (Tadashi_Cin) on 2023-10-01 11:53:08 +07:00, committed by GitHub
parent f9fc02a32a
commit 91048770fa
11 changed files with 919 additions and 0 deletions

@@ -0,0 +1,3 @@
from .base import BaseAgent
__all__ = ["BaseAgent"]

@@ -0,0 +1,67 @@
from enum import Enum
from typing import Dict, List, Union
from pydantic import BaseModel
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.pipelines.tools import BaseTool
from kotaemon.prompt.template import PromptTemplate
BaseLLM = Union[ChatLLM, LLM]
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
openai_memory = "openai_memory"
@staticmethod
def get_agent_class(_type: "AgentType"):
"""
Get agent class from agent type.
:param _type: agent type
:return: agent class
"""
if _type == AgentType.rewoo:
from .rewoo.agent import RewooAgent
return RewooAgent
else:
raise ValueError(f"Unknown agent type: {_type}")
class AgentOutput(BaseModel):
"""
Pydantic model for agent output.
"""
output: str
cost: float
token_usage: int
class BaseAgent(BaseTool):
name: str
"""Name of the agent."""
type: AgentType
"""Agent type, must be one of AgentType"""
description: str
"""Description used to tell the model how/when/why to use the agent.
You can provide few-shot examples as a part of the description. This will be
input to the prompt of LLM."""
llm: Union[BaseLLM, Dict[str, BaseLLM]]
"""Specify LLM to be used in the model, cam be a dict to supply different
LLMs to multiple purposes in the agent"""
    prompt_template: Union[PromptTemplate, Dict[str, PromptTemplate]]
    """A prompt template, or a dict to supply different prompts to different
    components of the agent."""
plugins: List[BaseTool]
"""List of plugins / tools to be used in the agent
"""

@@ -0,0 +1,219 @@
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Union
def check_log():
"""
Checks if logging has been enabled.
:return: True if logging has been enabled, False otherwise.
:rtype: bool
"""
return os.environ.get("LOG_PATH", None) is not None
class BaseScratchPad:
"""
Base class for output handlers.
Attributes:
-----------
logger : logging.Logger
The logger object to log messages.
Methods:
--------
stop():
Stop the output.
update_status(output: str, **kwargs):
Update the status of the output.
thinking(name: str):
Log that a process is thinking.
done(_all=False):
Log that the process is done.
stream_print(item: str):
Not implemented.
json_print(item: Dict[str, Any]):
Log a JSON object.
panel_print(item: Any, title: str = "Output", stream: bool = False):
Log a panel output.
clear():
Not implemented.
print(content: str, **kwargs):
Log arbitrary content.
format_json(json_obj: str):
Format a JSON object.
debug(content: str, **kwargs):
Log a debug message.
info(content: str, **kwargs):
Log an informational message.
warning(content: str, **kwargs):
Log a warning message.
error(content: str, **kwargs):
Log an error message.
critical(content: str, **kwargs):
Log a critical message.
"""
    def __init__(self):
        """
        Initialize the BaseScratchPad object.
        """
        # Use a named logger, matching the logging.Logger in the class docstring
        self.logger = logging.getLogger(__name__)
        self.log = []
def stop(self):
"""
Stop the output.
"""
def update_status(self, output: str, **kwargs):
"""
Update the status of the output.
"""
if check_log():
self.logger.info(output)
def thinking(self, name: str):
"""
Log that a process is thinking.
"""
if check_log():
self.logger.info(f"{name} is thinking...")
def done(self, _all=False):
"""
Log that the process is done.
"""
if check_log():
self.logger.info("Done")
def stream_print(self, item: str):
"""
Stream print.
"""
def json_print(self, item: Dict[str, Any]):
"""
Log a JSON object.
"""
if check_log():
self.logger.info(json.dumps(item, indent=2))
def panel_print(self, item: Any, title: str = "Output", stream: bool = False):
"""
Log a panel output.
Args:
item : Any
The item to log.
title : str, optional
The title of the panel, defaults to "Output".
stream : bool, optional
"""
if not stream:
self.log.append(item)
if check_log():
self.logger.info("-" * 20)
self.logger.info(item)
self.logger.info("-" * 20)
def clear(self):
"""
Not implemented.
"""
def print(self, content: str, **kwargs):
"""
Log arbitrary content.
"""
self.log.append(content)
if check_log():
self.logger.info(content)
def format_json(self, json_obj: str):
"""
Format a JSON object.
"""
formatted_json = json.dumps(json_obj, indent=2)
return formatted_json
def debug(self, content: str, **kwargs):
"""
Log a debug message.
"""
if check_log():
self.logger.debug(content, **kwargs)
def info(self, content: str, **kwargs):
"""
Log an informational message.
"""
if check_log():
self.logger.info(content, **kwargs)
def warning(self, content: str, **kwargs):
"""
Log a warning message.
"""
if check_log():
self.logger.warning(content, **kwargs)
def error(self, content: str, **kwargs):
"""
Log an error message.
"""
if check_log():
self.logger.error(content, **kwargs)
def critical(self, content: str, **kwargs):
"""
Log a critical message.
"""
if check_log():
self.logger.critical(content, **kwargs)
@dataclass
class AgentAction:
"""Agent's action to take.
Args:
tool: The tool to invoke.
tool_input: The input to the tool.
log: The log message.
"""
tool: str
tool_input: Union[str, dict]
log: str
class AgentFinish(NamedTuple):
"""Agent's return value when finishing execution.
Args:
return_values: The return values of the agent.
log: The log message.
"""
return_values: dict
log: str
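A short sketch of how the scratchpad above behaves (the log path is purely illustrative; check_log() only tests whether LOG_PATH is set and never opens the file):

import os

os.environ["LOG_PATH"] = "/tmp/agent.log"  # hypothetical value; enables logging

pad = BaseScratchPad()
pad.thinking("Planner")                   # logs "Planner is thinking..."
pad.print("raw content")                  # appended to pad.log and logged
pad.json_print({"#E1": "some evidence"})  # logged as indented JSON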

@@ -0,0 +1,3 @@
from .agent import RewooAgent
__all__ = ["RewooAgent"]

@@ -0,0 +1,270 @@
import logging
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, create_model
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.prompt.template import PromptTemplate
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from ..utils import get_plugin_response_content
from .planner import Planner
from .solver import Solver
class RewooAgent(BaseAgent):
"""Distributive RewooAgent class inherited from BaseAgent.
Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf"""
name: str = "RewooAgent"
type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions"
llm: Union[BaseLLM, Dict[str, BaseLLM]] # {"Planner": xxx, "Solver": xxx}
prompt_template: Dict[
str, PromptTemplate
] = dict() # {"Planner": xxx, "Solver": xxx}
plugins: List[BaseTool] = list()
examples: Dict[str, Union[str, List[str]]] = dict()
args_schema: Optional[Type[BaseModel]] = create_model(
"ReactArgsSchema", instruction=(str, ...)
)
    def _get_llms(self):
        if isinstance(self.llm, (ChatLLM, LLM)):
            return {"Planner": self.llm, "Solver": self.llm}
elif (
isinstance(self.llm, dict)
and "Planner" in self.llm
and "Solver" in self.llm
):
return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
else:
raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")
def _parse_plan_map(
self, planner_response: str
) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
"""
Parse planner output. It should be an n-to-n mapping from Plans to #Es.
This is because sometimes LLM cannot follow the strict output format.
Example:
#Plan1
#E1
#E2
should result in: {"#Plan1": ["#E1", "#E2"]}
Or:
#Plan1
#Plan2
#E1
should result in: {"#Plan1": [], "#Plan2": ["#E1"]}
        Returns:
            Tuple[Dict[str, List[str]], Dict[str, str]]: A mapping from each
            #Plan to its #E steps, and a mapping from each #Plan to its text.
        """
valid_chunk = [
line
for line in planner_response.splitlines()
if line.startswith("#Plan") or line.startswith("#E")
]
plan_to_es: Dict[str, List[str]] = dict()
plans: Dict[str, str] = dict()
        for line in valid_chunk:
            if line.startswith("#Plan"):
                plan = line.split(":", 1)[0].strip()
                plans[plan] = line.split(":", 1)[1].strip()
                plan_to_es[plan] = []
            elif line.startswith("#E") and plan_to_es:
                # Ignore #E lines that appear before any #Plan (malformed output)
                plan_to_es[plan].append(line.split(":", 1)[0].strip())
return plan_to_es, plans
def _parse_planner_evidences(
self, planner_response: str
) -> Tuple[Dict[str, str], List[List[str]]]:
"""
Parse planner output. This should return a mapping from #E to tool call.
It should also identify the level of each #E in dependency map.
Example:
{
"#E1": "Tool1", "#E2": "Tool2",
"#E3": "Tool3", "#E4": "Tool4"
}, [[#E1, #E2], [#E3, #E4]]
Returns:
Tuple[dict[str, str], List[List[str]]]:
A mapping from #E to tool call and a list of levels.
"""
evidences: Dict[str, str] = dict()
dependence: Dict[str, List[str]] = dict()
        for line in planner_response.splitlines():
            # Slice so a bare "#E" line cannot raise IndexError
            if line.startswith("#E") and line[2:3].isdigit():
                e, tool_call = line.split(":", 1)
e, tool_call = e.strip(), tool_call.strip()
if len(e) == 3:
dependence[e] = []
evidences[e] = tool_call
for var in re.findall(r"#E\d+", tool_call):
if var in evidences:
dependence[e].append(var)
else:
evidences[e] = "No evidence found"
level = []
while dependence:
select = [i for i in dependence if not dependence[i]]
if len(select) == 0:
raise ValueError("Circular dependency detected.")
level.append(select)
for item in select:
dependence.pop(item)
for item in dependence:
for i in select:
if i in dependence[item]:
dependence[item].remove(i)
return evidences, level
    def _run_plugin(
        self,
        e: str,
        planner_evidences: Dict[str, str],
        worker_evidences: Dict[str, str],
        output: Optional[BaseScratchPad] = None,
    ):
        """
        Run a plugin for a given piece of evidence.
        This function also accumulates the cost and tokens of
        agent-as-plugin calls.
        """
        # Avoid sharing a mutable default scratchpad across calls
        output = output or BaseScratchPad()
        result = dict(e=e, plugin_cost=0, plugin_token=0, evidence="")
tool_call = planner_evidences[e]
if "[" not in tool_call:
result["evidence"] = tool_call
else:
tool, tool_input = tool_call.split("[", 1)
tool_input = tool_input[:-1]
# find variables in input and replace with previous evidences
for var in re.findall(r"#E\d+", tool_input):
if var in worker_evidences:
tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
try:
selected_plugin = self._find_plugin(tool)
if selected_plugin is None:
raise ValueError("Invalid plugin detected")
tool_response = selected_plugin(tool_input)
# cumulate agent-as-plugin costs and tokens.
if isinstance(tool_response, AgentOutput):
result["plugin_cost"] = tool_response.cost
result["plugin_token"] = tool_response.token_usage
result["evidence"] = get_plugin_response_content(tool_response)
except ValueError:
result["evidence"] = "No evidence found."
finally:
output.panel_print(
result["evidence"], f"[green] Function Response of [blue]{tool}: "
)
return result
    def _get_worker_evidence(
        self,
        planner_evidences: Dict[str, str],
        evidences_level: List[List[str]],
        output: Optional[BaseScratchPad] = None,
    ) -> Tuple[Dict[str, str], float, float]:
"""
Parallel execution of plugins in DAG for speedup.
This is one of core benefits of ReWOO agents.
Args:
planner_evidences: A mapping from #E to tool call.
evidences_level: A list of levels of evidences.
Calculated from DAG of plugin calls.
            output: Output object; defaults to a new BaseScratchPad.
        Returns:
            A mapping from #E to its computed evidence, plus the accumulated
            plugin cost and token usage.
        """
        output = output or BaseScratchPad()
        worker_evidences: Dict[str, str] = dict()
        plugin_cost, plugin_token = 0.0, 0.0
with ThreadPoolExecutor() as pool:
for level in evidences_level:
results = []
for e in level:
results.append(
pool.submit(
self._run_plugin,
e,
planner_evidences,
worker_evidences,
output,
)
)
if len(results) > 1:
output.update_status(f"Running tasks {level} in parallel.")
else:
output.update_status(f"Running task {level[0]}.")
for r in results:
resp = r.result()
plugin_cost += resp["plugin_cost"]
plugin_token += resp["plugin_token"]
worker_evidences[resp["e"]] = resp["evidence"]
output.done()
return worker_evidences, plugin_cost, plugin_token
def _find_plugin(self, name: str):
for p in self.plugins:
if p.name == name:
return p
def _run_tool(self, instruction: str) -> AgentOutput:
"""
Run the agent with a given instruction.
"""
logging.info(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
planner_llm = self._get_llms()["Planner"]
solver_llm = self._get_llms()["Solver"]
planner = Planner(
model=planner_llm,
plugins=self.plugins,
prompt_template=self.prompt_template.get("Planner", None),
examples=self.examples.get("Planner", None),
)
solver = Solver(
model=solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
# Plan
planner_output = planner(instruction)
        planner_text_output = planner_output.text[0]
        plan_to_es, plans = self._parse_plan_map(planner_text_output)
        planner_evidences, evidence_level = self._parse_planner_evidences(
            planner_text_output
        )
# Work
        worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
            planner_evidences, evidence_level
        )
        # Fold agent-as-plugin usage into the totals returned below
        total_cost += plugin_cost
        total_token += int(plugin_token)
worker_log = ""
for plan in plan_to_es:
worker_log += f"{plan}: {plans[plan]}\n"
for e in plan_to_es[plan]:
worker_log += f"{e}: {worker_evidences[e]}\n"
# Solve
solver_output = solver(instruction, worker_log)
solver_output_text = solver_output.text[0]
return AgentOutput(
output=solver_output_text, cost=total_cost, token_usage=total_token
)
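A hedged wiring sketch for the class above (`planner_llm`, `solver_llm`, and `search_tool` are placeholders; `_get_llms` accepts either a single LLM or a dict with "Planner" and "Solver" keys):

agent = RewooAgent(
    llm={"Planner": planner_llm, "Solver": solver_llm},
    plugins=[search_tool],
)
response = agent("What is the 4th root of 64 to the power of 3?")  # AgentOutput
print(response.output, response.cost, response.token_usage)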

@@ -0,0 +1,82 @@
from typing import Any, List, Optional, Union
from kotaemon.base import BaseComponent
from kotaemon.prompt.template import PromptTemplate
from ..base import BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
class Planner(BaseComponent):
model: BaseLLM
prompt_template: Optional[PromptTemplate] = None
examples: Optional[Union[str, List[str]]] = None
plugins: List[BaseTool]
def _compose_worker_description(self) -> str:
"""
Compose the worker prompt from the workers.
Example:
toolname1[input]: tool1 description
toolname2[input]: tool2 description
"""
prompt = ""
try:
for worker in self.plugins:
prompt += f"{worker.name}[input]: {worker.description}\n"
except Exception:
raise ValueError("Worker must have a name and description.")
return prompt
def _compose_fewshot_prompt(self) -> str:
if self.examples is None:
return ""
if isinstance(self.examples, str):
return self.examples
else:
return "\n\n".join([e.strip("\n") for e in self.examples])
    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.
        """
        worker_description = self._compose_worker_description()
        fewshot = self._compose_fewshot_prompt()
        if self.prompt_template is not None:
            if "fewshot" in self.prompt_template.placeholders:
                return self.prompt_template.populate(
                    tool_description=worker_description,
                    fewshot=fewshot,
                    task=instruction,
                )
            else:
                return self.prompt_template.populate(
                    tool_description=worker_description, task=instruction
                )
        else:
            if self.examples is not None:
                # Use the few-shot template, which has the {fewshot} placeholder
                return few_shot_planner_prompt.populate(
                    tool_description=worker_description,
                    fewshot=fewshot,
                    task=instruction,
                )
            else:
                return zero_shot_planner_prompt.populate(
                    tool_description=worker_description, task=instruction
                )
    def run(self, instruction: str, output: Optional[BaseScratchPad] = None) -> Any:
        output = output or BaseScratchPad()
        response = None
        output.info("Running Planner")
        prompt = self._compose_prompt(instruction)
        output.debug(f"Prompt: {prompt}")
        try:
            response = self.model(prompt)
            output.info("Planner run successful.")
        except ValueError as exc:
            output.error("Planner failed to retrieve response from LLM")
            raise ValueError("Planner failed to retrieve response from LLM") from exc
        return response
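A usage sketch (BaseComponent makes the Planner callable, as RewooAgent._run_tool relies on; the model and tools here are placeholders):

planner = Planner(model=planner_llm, plugins=[search_tool, wiki_tool])
plan_text = planner("Tell me about Cinnamon AI company").text[0]  # raw #Plan/#E text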

@@ -0,0 +1,119 @@
# flake8: noqa
from kotaemon.prompt.template import PromptTemplate
zero_shot_planner_prompt = PromptTemplate(
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
##Available Tools##
{tool_description}
##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>] (eg. Search[What is Python])
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...
##Your Task##
{task}
##Now Begin##
"""
)
one_shot_planner_prompt = PromptTemplate(
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
##Available Tools##
{tool_description}
##Output Format##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>]
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...
##Example##
Task: What is the 4th root of 64 to the power of 3?
#Plan1: Find the 4th root of 64
#E1: Calculator[64^(1/4)]
#Plan2: Raise the result from #Plan1 to the power of 3
#E2: Calculator[#E1^3]
##Your Task##
{task}
##Now Begin##
"""
)
few_shot_planner_prompt = PromptTemplate(
template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.
##Available Tools##
{tool_description}
##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input>]
#Plan2: <describe next plan>
#E2: <toolname>[<input, you can use #E1 to represent its expected output>]
And so on...
##Examples##
{fewshot}
##Your Task##
{task}
##Now Begin##
"""
)
zero_shot_solver_prompt = PromptTemplate(
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.
##My Plans and Evidences##
{plan_evidence}
##Example Output##
First, I <did something>, and I think <...>; Second, I <...>, and I think <...>; ...
So, <your conclusion>.
##Your Task##
{task}
##Now Begin##
"""
)
few_shot_solver_prompt = PromptTemplate(
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.
##My Plans and Evidences##
{plan_evidence}
##Example Output##
First, I <did something>, and I think <...>; Second, I <...>, and I think <...>; ...
So, <your conclusion>.
##Example##
{fewshot}
##Your Task##
{task}
##Now Begin##
"""
)
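These templates are filled via PromptTemplate.populate, as the Planner and Solver do; for instance (illustrative values):

prompt = zero_shot_planner_prompt.populate(
    tool_description="google_search[input]: Search Google for the query.",  # made-up tool line
    task="What is the 4th root of 64 to the power of 3?",
)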

@@ -0,0 +1,66 @@
from typing import Any, List, Optional, Union
from kotaemon.base import BaseComponent
from kotaemon.prompt.template import PromptTemplate
from ..base import BaseLLM
from ..output.base import BaseScratchPad
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
class Solver(BaseComponent):
model: BaseLLM
prompt_template: Optional[PromptTemplate] = None
examples: Optional[Union[str, List[str]]] = None
def _compose_fewshot_prompt(self) -> str:
if self.examples is None:
return ""
if isinstance(self.examples, str):
return self.examples
else:
return "\n\n".join([e.strip("\n") for e in self.examples])
def _compose_prompt(self, instruction, plan_evidence) -> str:
"""
Compose the prompt from template, plan&evidence, examples and instruction.
"""
fewshot = self._compose_fewshot_prompt()
if self.prompt_template is not None:
if "fewshot" in self.prompt_template.placeholders:
return self.prompt_template.populate(
plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
)
else:
return self.prompt_template.populate(
plan_evidence=plan_evidence, task=instruction
)
else:
if self.examples is not None:
return few_shot_solver_prompt.populate(
plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
)
else:
return zero_shot_solver_prompt.populate(
plan_evidence=plan_evidence, task=instruction
)
    def run(
        self,
        instruction: str,
        plan_evidence: str,
        output: Optional[BaseScratchPad] = None,
    ) -> Any:
        output = output or BaseScratchPad()
        response = None
        output.info("Running Solver")
        output.debug(f"Instruction: {instruction}")
        output.debug(f"Plan Evidence: {plan_evidence}")
        prompt = self._compose_prompt(instruction, plan_evidence)
        output.debug(f"Prompt: {prompt}")
        try:
            response = self.model(prompt)
            output.info("Solver run successful.")
        except ValueError:
            output.error("Solver failed to retrieve response from LLM")
        return response
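A matching sketch for the Solver (placeholder model; `worker_log` is the #Plan/#E transcript that RewooAgent._run_tool assembles):

solver = Solver(model=solver_llm)
answer = solver("What is the 4th root of 64 to the power of 3?", worker_log).text[0]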

@@ -0,0 +1,22 @@
from .base import AgentOutput
def get_plugin_response_content(output) -> str:
    """
    Return the text content of a plugin response, unwrapping AgentOutput.
    """
if isinstance(output, AgentOutput):
return output.output
else:
return str(output)
def calculate_cost(model_name: str, prompt_token: int, completion_token: int) -> float:
"""
Calculate the cost of a prompt and completion.
Returns:
float: Cost of the provided model name with provided token information
"""
# TODO: to be implemented
return 0.0
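A quick illustration of the helper above (AgentOutput is the model defined in base.py):

assert get_plugin_response_content(
    AgentOutput(output="hi", cost=0.0, token_usage=0)
) == "hi"
assert get_plugin_response_content(42) == "42"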

tests/test_agent.py (new file, 68 lines)

@@ -0,0 +1,68 @@
from unittest.mock import patch
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents.rewoo import RewooAgent
from kotaemon.pipelines.tools import GoogleSearchTool, WikipediaTool
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
_openai_chat_completion_responses = [
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": "#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]",
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
},
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": FINAL_RESPONSE_TEXT,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
},
]
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses,
)
def test_rewoo_agent(openai_completion):
llm = AzureChatOpenAI(
openai_api_base="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2",
temperature=0,
)
plugins = [GoogleSearchTool(), WikipediaTool()]
agent = RewooAgent(llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT