[AUR-431, AUR-435] Add Agent Interface and ReWOO Agent implementation (#31)
* add base Tool * minor update test_tool * update test dependency * update test dependency * Fix namespace conflict * update test * add base Agent Interface, add ReWoo Agent * minor update * update test * fix typo * remove unneeded print * update rewoo agent --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
parent
f9fc02a32a
commit
91048770fa
3
knowledgehub/pipelines/agents/__init__.py
Normal file
3
knowledgehub/pipelines/agents/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .base import BaseAgent
|
||||||
|
|
||||||
|
__all__ = ["BaseAgent"]
|
67
knowledgehub/pipelines/agents/base.py
Normal file
67
knowledgehub/pipelines/agents/base.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from kotaemon.llms.chats.base import ChatLLM
|
||||||
|
from kotaemon.llms.completions.base import LLM
|
||||||
|
from kotaemon.pipelines.tools import BaseTool
|
||||||
|
from kotaemon.prompt.template import PromptTemplate
|
||||||
|
|
||||||
|
# Alias accepted wherever an agent needs a language model: either a chat
# model (ChatLLM) or a completion model (LLM).
BaseLLM = Union[ChatLLM, LLM]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentType(Enum):
    """
    Closed set of agent kinds known to the library.
    """

    openai = "openai"
    react = "react"
    rewoo = "rewoo"
    vanilla = "vanilla"
    openai_memory = "openai_memory"

    @staticmethod
    def get_agent_class(_type: "AgentType"):
        """
        Resolve an agent type to the class implementing it.

        :param _type: agent type to look up
        :return: agent class for ``_type``
        :raises ValueError: if ``_type`` has no implementation here
        """
        if _type != AgentType.rewoo:
            raise ValueError(f"Unknown agent type: {_type}")

        # Imported lazily: the rewoo agent module imports from this module.
        from .rewoo.agent import RewooAgent

        return RewooAgent
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutput(BaseModel):
    """
    Pydantic model for agent output.

    Fields:
        output: text returned by the agent.
        cost: cost reported for the run.
        token_usage: number of tokens reported for the run.
    """

    output: str
    cost: float
    token_usage: int
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(BaseTool):
    name: str
    """Name of the agent."""
    type: AgentType
    """Agent type, must be one of AgentType"""
    description: str
    """Description used to tell the model how/when/why to use the agent.
    You can provide few-shot examples as a part of the description. This will be
    input to the prompt of LLM."""
    llm: Union[BaseLLM, Dict[str, BaseLLM]]
    """Specify LLM to be used in the model, can be a dict to supply different
    LLMs to multiple purposes in the agent"""
    prompt_template: Union[PromptTemplate, Dict[str, PromptTemplate]]
    """A prompt template or a dict to supply different prompt to the agent
    """
    plugins: List[BaseTool]
    """List of plugins / tools to be used in the agent
    """
|
0
knowledgehub/pipelines/agents/output/__init__.py
Normal file
0
knowledgehub/pipelines/agents/output/__init__.py
Normal file
219
knowledgehub/pipelines/agents/output/base.py
Normal file
219
knowledgehub/pipelines/agents/output/base.py
Normal file
|
@ -0,0 +1,219 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, NamedTuple, Union
|
||||||
|
|
||||||
|
|
||||||
|
def check_log():
    """
    Report whether logging has been switched on through the environment.

    :return: True when the ``LOG_PATH`` environment variable is set
        (even to an empty string), False otherwise.
    :rtype: bool
    """
    log_path = os.environ.get("LOG_PATH")
    return log_path is not None
|
||||||
|
|
||||||
|
|
||||||
|
class BaseScratchPad:
    """
    Base class for output handlers.

    Log-style methods forward to the standard ``logging`` module, but only
    while ``check_log()`` returns True. ``print`` and non-streaming
    ``panel_print`` additionally record their content in ``self.log``.

    Attributes:
    -----------
    logger : module
        The ``logging`` module, used to emit messages.
    log : list
        Items recorded by ``print`` and ``panel_print``.

    Methods:
    --------
    stop(), stream_print(item), clear():
        Do nothing in this base class.
    update_status(output, **kwargs):
        Log a status message.
    thinking(name):
        Log that ``name`` is thinking.
    done(_all=False):
        Log that the process is done.
    json_print(item):
        Log ``item`` as indented JSON.
    panel_print(item, title="Output", stream=False):
        Record (unless streaming) and log a panel output.
    print(content, **kwargs):
        Record and log arbitrary content.
    format_json(json_obj):
        Serialize an object to indented JSON.
    debug / info / warning / error / critical(content, **kwargs):
        Log a message at the corresponding level.
    """

    def __init__(self):
        """
        Initialize the scratch pad with the ``logging`` module as logger and
        an empty record list.
        """
        self.logger = logging
        self.log = []

    def stop(self):
        """
        Stop the output. Does nothing here.
        """

    def update_status(self, output: str, **kwargs):
        """
        Log a status message; ``kwargs`` are accepted but not used.
        """
        if not check_log():
            return
        self.logger.info(output)

    def thinking(self, name: str):
        """
        Log that ``name`` is thinking.
        """
        if not check_log():
            return
        self.logger.info(f"{name} is thinking...")

    def done(self, _all=False):
        """
        Log that the process is done; ``_all`` is accepted but not used.
        """
        if not check_log():
            return
        self.logger.info("Done")

    def stream_print(self, item: str):
        """
        Stream print. Does nothing here.
        """

    def json_print(self, item: Dict[str, Any]):
        """
        Log ``item`` serialized as JSON with two-space indentation.
        """
        if not check_log():
            return
        self.logger.info(json.dumps(item, indent=2))

    def panel_print(self, item: Any, title: str = "Output", stream: bool = False):
        """
        Log a panel output framed by two dashed rules.

        Args:
            item : Any
                The item to log.
            title : str, optional
                The title of the panel, defaults to "Output". Not used here.
            stream : bool, optional
                When False (default) the item is also appended to
                ``self.log``.
        """
        if not stream:
            self.log.append(item)
        if not check_log():
            return
        rule = "-" * 20
        self.logger.info(rule)
        self.logger.info(item)
        self.logger.info(rule)

    def clear(self):
        """
        Not implemented.
        """

    def print(self, content: str, **kwargs):
        """
        Record ``content`` in ``self.log`` and log it.
        """
        self.log.append(content)
        if not check_log():
            return
        self.logger.info(content)

    def format_json(self, json_obj: str):
        """
        Return ``json_obj`` serialized as JSON with two-space indentation.
        """
        return json.dumps(json_obj, indent=2)

    def debug(self, content: str, **kwargs):
        """
        Log a debug message.
        """
        if not check_log():
            return
        self.logger.debug(content, **kwargs)

    def info(self, content: str, **kwargs):
        """
        Log an informational message.
        """
        if not check_log():
            return
        self.logger.info(content, **kwargs)

    def warning(self, content: str, **kwargs):
        """
        Log a warning message.
        """
        if not check_log():
            return
        self.logger.warning(content, **kwargs)

    def error(self, content: str, **kwargs):
        """
        Log an error message.
        """
        if not check_log():
            return
        self.logger.error(content, **kwargs)

    def critical(self, content: str, **kwargs):
        """
        Log a critical message.
        """
        if not check_log():
            return
        self.logger.critical(content, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AgentAction:
    """Agent's action to take.

    Args:
        tool: The tool to invoke, given as a string.
        tool_input: The input to the tool, either a string or a dict.
        log: The log message.
    """

    tool: str
    tool_input: Union[str, dict]
    log: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFinish(NamedTuple):
    """Agent's return value when finishing execution.

    Args:
        return_values: The return values of the agent, as a dict.
        log: The log message.
    """

    return_values: dict
    log: str
|
3
knowledgehub/pipelines/agents/rewoo/__init__.py
Normal file
3
knowledgehub/pipelines/agents/rewoo/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .agent import RewooAgent
|
||||||
|
|
||||||
|
__all__ = ["RewooAgent"]
|
270
knowledgehub/pipelines/agents/rewoo/agent.py
Normal file
270
knowledgehub/pipelines/agents/rewoo/agent.py
Normal file
|
@ -0,0 +1,270 @@
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, create_model
|
||||||
|
|
||||||
|
from kotaemon.llms.chats.base import ChatLLM
|
||||||
|
from kotaemon.llms.completions.base import LLM
|
||||||
|
from kotaemon.prompt.template import PromptTemplate
|
||||||
|
|
||||||
|
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
|
||||||
|
from ..output.base import BaseScratchPad
|
||||||
|
from ..utils import get_plugin_response_content
|
||||||
|
from .planner import Planner
|
||||||
|
from .solver import Solver
|
||||||
|
|
||||||
|
|
||||||
|
class RewooAgent(BaseAgent):
    """Distributive RewooAgent class inherited from BaseAgent.
    Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf

    A run has three stages: the Planner LLM writes ``#Plan``/``#E`` lines,
    the plugins named in the ``#E`` lines are executed (independent ones in
    parallel), and the Solver LLM writes the final answer from the plans and
    the collected evidence.
    """

    name: str = "RewooAgent"
    type: AgentType = AgentType.rewoo
    description: str = "RewooAgent for answering multi-step reasoning questions"
    llm: Union[BaseLLM, Dict[str, BaseLLM]]  # {"Planner": xxx, "Solver": xxx}
    prompt_template: Dict[
        str, PromptTemplate
    ] = dict()  # {"Planner": xxx, "Solver": xxx}
    plugins: List[BaseTool] = list()
    examples: Dict[str, Union[str, List[str]]] = dict()
    args_schema: Optional[Type[BaseModel]] = create_model(
        "ReactArgsSchema", instruction=(str, ...)
    )

    def _get_llms(self):
        """
        Return the LLMs keyed by role.

        A single LLM is used for both roles; a dict must provide both a
        "Planner" and a "Solver" entry.

        Raises:
            ValueError: if ``llm`` is neither of the above.
        """
        if isinstance(self.llm, (ChatLLM, LLM)):
            return {"Planner": self.llm, "Solver": self.llm}
        if (
            isinstance(self.llm, dict)
            and "Planner" in self.llm
            and "Solver" in self.llm
        ):
            return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
        raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")

    def _parse_plan_map(
        self, planner_response: str
    ) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
        """
        Parse planner output. It should be an n-to-n mapping from Plans to #Es.
        This is because sometimes LLM cannot follow the strict output format.
        Example:
            #Plan1
            #E1
            #E2
        should result in: {"#Plan1": ["#E1", "#E2"]}
        Or:
            #Plan1
            #Plan2
            #E1
        should result in: {"#Plan1": [], "#Plan2": ["#E1"]}

        Lines that start with neither "#Plan" nor "#E" are ignored, as is an
        "#E" line that appears before the first "#Plan" line. A "#Plan" line
        without a colon gets an empty description.

        Returns:
            Tuple[Dict[str, List[str]], Dict[str, str]]: the mapping from
                each plan label to its evidence labels, and the mapping from
                each plan label to its description.
        """
        plan_to_es: Dict[str, List[str]] = dict()
        plans: Dict[str, str] = dict()
        current_plan: Optional[str] = None
        for line in planner_response.splitlines():
            if line.startswith("#Plan"):
                label, _, text = line.partition(":")
                current_plan = label.strip()
                plans[current_plan] = text.strip()
                plan_to_es[current_plan] = []
            elif line.startswith("#E") and current_plan is not None:
                plan_to_es[current_plan].append(line.split(":", 1)[0].strip())

        return plan_to_es, plans

    def _parse_planner_evidences(
        self, planner_response: str
    ) -> Tuple[Dict[str, str], List[List[str]]]:
        """
        Parse planner output. This should return a mapping from #E to tool call.
        It should also identify the level of each #E in dependency map.
        Example:
            {
            "#E1": "Tool1", "#E2": "Tool2",
            "#E3": "Tool3", "#E4": "Tool4"
            }, [[#E1, #E2], [#E3, #E4]]

        Returns:
            Tuple[dict[str, str], List[List[str]]]:
            A mapping from #E to tool call and a list of levels.

        Raises:
            ValueError: if the evidences reference each other in a cycle.
        """
        evidences: Dict[str, str] = dict()
        dependence: Dict[str, List[str]] = dict()
        for line in planner_response.splitlines():
            if not re.match(r"#E\d", line):
                continue
            e, _, tool_call = line.partition(":")
            e, tool_call = e.strip(), tool_call.strip()
            if re.fullmatch(r"#E\d+", e):  # any number of digits, e.g. #E10
                dependence[e] = []
                evidences[e] = tool_call
                for var in re.findall(r"#E\d+", tool_call):
                    # Record each dependency once; a repeated reference must
                    # not survive the removal pass below.
                    if var in evidences and var not in dependence[e]:
                        dependence[e].append(var)
            else:
                evidences[e] = "No evidence found"

        # Peel off, level by level, the evidences with no unresolved dependency.
        level = []
        while dependence:
            select = [i for i in dependence if not dependence[i]]
            if not select:
                raise ValueError("Circular dependency detected.")
            level.append(select)
            for item in select:
                dependence.pop(item)
            for pending in dependence.values():
                pending[:] = [dep for dep in pending if dep not in select]

        return evidences, level

    def _run_plugin(
        self,
        e: str,
        planner_evidences: Dict[str, str],
        worker_evidences: Dict[str, str],
        output=None,
    ):
        """
        Run a plugin for a given evidence.

        Args:
            e: Evidence label, a key of ``planner_evidences``.
            planner_evidences: A mapping from #E to tool call.
            worker_evidences: Evidence gathered so far; ``#E<n>`` references
                in the tool input are replaced with these values.
            output: Scratch pad for reporting; a new BaseScratchPad is
                created when omitted.

        Returns:
            dict with keys "e", "plugin_cost", "plugin_token" and "evidence".
            Cost and tokens are non-zero only when the plugin returns an
            AgentOutput.
        """
        if output is None:
            output = BaseScratchPad()
        result = dict(e=e, plugin_cost=0, plugin_token=0, evidence="")
        tool_call = planner_evidences[e]
        if "[" not in tool_call:
            # Not a tool call: the text itself is the evidence.
            result["evidence"] = tool_call
        else:
            tool, tool_input = tool_call.split("[", 1)
            tool = tool.strip()
            # Keep what precedes the last closing bracket, so trailing text
            # after "]" or a missing "]" does not corrupt the input.
            if "]" in tool_input:
                tool_input = tool_input.rsplit("]", 1)[0]
            # Substitute whole variables in one pass so that "#E1" is never
            # replaced inside "#E10"; unknown variables are left untouched.
            tool_input = re.sub(
                r"#E\d+",
                lambda match: worker_evidences.get(match.group(0), match.group(0)),
                tool_input,
            )
            try:
                selected_plugin = self._find_plugin(tool)
                if selected_plugin is None:
                    raise ValueError("Invalid plugin detected")
                tool_response = selected_plugin(tool_input)
                # cumulate agent-as-plugin costs and tokens.
                if isinstance(tool_response, AgentOutput):
                    result["plugin_cost"] = tool_response.cost
                    result["plugin_token"] = tool_response.token_usage
                result["evidence"] = get_plugin_response_content(tool_response)
            except ValueError:
                result["evidence"] = "No evidence found."
            finally:
                output.panel_print(
                    result["evidence"], f"[green] Function Response of [blue]{tool}: "
                )
        return result

    def _get_worker_evidence(
        self,
        planner_evidences: Dict[str, str],
        evidences_level: List[List[str]],
        output=None,
    ) -> Any:
        """
        Parallel execution of plugins in DAG for speedup.
        This is one of core benefits of ReWOO agents.

        Args:
            planner_evidences: A mapping from #E to tool call.
            evidences_level: A list of levels of evidences.
                Calculated from DAG of plugin calls.
            output: Scratch pad for reporting; a new BaseScratchPad is
                created when omitted.
        Returns:
            A tuple ``(worker_evidences, plugin_cost, plugin_token)``: the
            mapping from #E to the evidence produced, and the cost and token
            totals reported by the plugins.
        """
        if output is None:
            output = BaseScratchPad()
        worker_evidences: Dict[str, str] = dict()
        plugin_cost, plugin_token = 0.0, 0
        with ThreadPoolExecutor() as pool:
            for level in evidences_level:
                if not level:
                    continue
                futures = [
                    pool.submit(
                        self._run_plugin,
                        e,
                        planner_evidences,
                        worker_evidences,
                        output,
                    )
                    for e in level
                ]
                if len(futures) > 1:
                    output.update_status(f"Running tasks {level} in parallel.")
                else:
                    output.update_status(f"Running task {level[0]}.")
                # Wait for the whole level before starting the next one, so
                # later levels can read these results.
                for future in futures:
                    resp = future.result()
                    plugin_cost += resp["plugin_cost"]
                    plugin_token += resp["plugin_token"]
                    worker_evidences[resp["e"]] = resp["evidence"]
                output.done()

        return worker_evidences, plugin_cost, plugin_token

    def _find_plugin(self, name: str):
        """Return the first plugin whose name equals ``name``, or None."""
        return next((p for p in self.plugins if p.name == name), None)

    def _run_tool(self, instruction: str) -> AgentOutput:
        """
        Run the agent with a given instruction.

        Returns:
            AgentOutput carrying the solver's text. ``cost`` and
            ``token_usage`` include only what the plugins report; nothing in
            this method adds the planner or solver LLM usage.
        """
        logging.info(f"Running {self.name} with instruction: {instruction}")
        total_cost = 0.0
        total_token = 0

        llms = self._get_llms()
        planner = Planner(
            model=llms["Planner"],
            plugins=self.plugins,
            prompt_template=self.prompt_template.get("Planner", None),
            examples=self.examples.get("Planner", None),
        )
        solver = Solver(
            model=llms["Solver"],
            prompt_template=self.prompt_template.get("Solver", None),
            examples=self.examples.get("Solver", None),
        )

        # Plan
        planner_output = planner(instruction)
        planner_text_output = planner_output.text[0]
        plan_to_es, plans = self._parse_plan_map(planner_text_output)
        planner_evidences, evidence_level = self._parse_planner_evidences(
            planner_text_output
        )

        # Work
        worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
            planner_evidences, evidence_level
        )
        total_cost += plugin_cost
        total_token += plugin_token
        log_lines = []
        for plan, evidence_labels in plan_to_es.items():
            log_lines.append(f"{plan}: {plans[plan]}\n")
            for e in evidence_labels:
                # A label the evidence parser rejected has no worker result.
                evidence = worker_evidences.get(e, "No evidence found.")
                log_lines.append(f"{e}: {evidence}\n")
        worker_log = "".join(log_lines)

        # Solve
        solver_output = solver(instruction, worker_log)
        solver_output_text = solver_output.text[0]

        return AgentOutput(
            output=solver_output_text, cost=total_cost, token_usage=total_token
        )
|
82
knowledgehub/pipelines/agents/rewoo/planner.py
Normal file
82
knowledgehub/pipelines/agents/rewoo/planner.py
Normal file
|
@ -0,0 +1,82 @@
|
||||||
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
from kotaemon.prompt.template import PromptTemplate
|
||||||
|
|
||||||
|
from ..base import BaseLLM, BaseTool
|
||||||
|
from ..output.base import BaseScratchPad
|
||||||
|
from .prompt import zero_shot_planner_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class Planner(BaseComponent):
    """Build the planning prompt and ask the LLM for a plan.

    The prompt is composed from a template, a description of the available
    plugins, optional few-shot examples and the instruction.
    """

    model: BaseLLM
    prompt_template: Optional[PromptTemplate] = None
    examples: Optional[Union[str, List[str]]] = None
    plugins: List[BaseTool]

    def _compose_worker_description(self) -> str:
        """
        Compose the worker prompt from the workers.

        Example:
        toolname1[input]: tool1 description
        toolname2[input]: tool2 description

        Raises:
            ValueError: if reading a plugin's name or description fails.
        """
        try:
            return "".join(
                f"{worker.name}[input]: {worker.description}\n"
                for worker in self.plugins
            )
        except Exception as err:
            raise ValueError("Worker must have a name and description.") from err

    def _compose_fewshot_prompt(self) -> str:
        """Return the examples as one string ("" when there are none).

        A list of examples is joined with blank lines after stripping
        leading/trailing newlines from each entry.
        """
        if self.examples is None:
            return ""
        if isinstance(self.examples, str):
            return self.examples
        return "\n\n".join([e.strip("\n") for e in self.examples])

    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.

        A user-supplied template wins; it receives ``fewshot`` only if it
        declares that placeholder. Otherwise the built-in few-shot template
        is used when examples are given, and the zero-shot template when not.
        """
        worker_description = self._compose_worker_description()
        fewshot = self._compose_fewshot_prompt()

        if self.prompt_template is not None:
            if "fewshot" in self.prompt_template.placeholders:
                return self.prompt_template.populate(
                    tool_description=worker_description,
                    fewshot=fewshot,
                    task=instruction,
                )
            return self.prompt_template.populate(
                tool_description=worker_description, task=instruction
            )

        if self.examples is not None:
            # The zero-shot template has no {fewshot} placeholder, so the
            # examples need the few-shot template to reach the prompt.
            from .prompt import few_shot_planner_prompt

            return few_shot_planner_prompt.populate(
                tool_description=worker_description,
                fewshot=fewshot,
                task=instruction,
            )
        return zero_shot_planner_prompt.populate(
            tool_description=worker_description, task=instruction
        )

    def run(self, instruction: str, output: Optional[BaseScratchPad] = None) -> Any:
        """Send the composed prompt to the model and return its response.

        Args:
            instruction: The task to plan for.
            output: Scratch pad for reporting; a new BaseScratchPad is
                created when omitted.

        Raises:
            ValueError: if the model call raises ValueError.
        """
        if output is None:
            output = BaseScratchPad()
        output.info("Running Planner")
        prompt = self._compose_prompt(instruction)
        output.debug(f"Prompt: {prompt}")
        try:
            response = self.model(prompt)
        except ValueError as err:
            output.error("Planner failed to retrieve response from LLM")
            raise ValueError("Planner failed to retrieve response from LLM") from err
        output.info("Planner run successful.")

        return response
|
119
knowledgehub/pipelines/agents/rewoo/prompt.py
Normal file
119
knowledgehub/pipelines/agents/rewoo/prompt.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
# flake8: noqa
|
||||||
|
|
||||||
|
from kotaemon.prompt.template import PromptTemplate
|
||||||
|
|
||||||
|
# Planner prompt without examples.
# Placeholders: {tool_description}, {task}.
zero_shot_planner_prompt = PromptTemplate(
    template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.

##Available Tools##
{tool_description}

##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>] (eg. Search[What is Python])
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...

##Your Task##
{task}

##Now Begin##
"""
)
|
||||||
|
|
||||||
|
# Planner prompt with one built-in worked example (a Calculator task).
# Placeholders: {tool_description}, {task}.
one_shot_planner_prompt = PromptTemplate(
    template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.

##Available Tools##
{tool_description}

##Output Format##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>]
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...

##Example##
Task: What is the 4th root of 64 to the power of 3?
#Plan1: Find the 4th root of 64
#E1: Calculator[64^(1/4)]
#Plan2: Raise the result from #Plan1 to the power of 3
#E2: Calculator[#E1^3]

##Your Task##
{task}

##Now Begin##
"""
)
|
||||||
|
|
||||||
|
|
||||||
|
# Planner prompt with caller-supplied examples.
# Placeholders: {tool_description}, {fewshot}, {task}.
few_shot_planner_prompt = PromptTemplate(
    template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.

##Available Tools##
{tool_description}

##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input>]
#Plan2: <describe next plan>
#E2: <toolname>[<input, you can use #E1 to represent its expected output>]
And so on...

##Examples##
{fewshot}

##Your Task##
{task}

##Now Begin##
"""
)
|
||||||
|
|
||||||
|
# Solver prompt without examples.
# Placeholders: {plan_evidence}, {task}.
zero_shot_solver_prompt = PromptTemplate(
    template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.

##My Plans and Evidences##
{plan_evidence}

##Example Output##
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
So, <your conclusion>.

##Your Task##
{task}

##Now Begin##
"""
)
|
||||||
|
|
||||||
|
# Solver prompt with caller-supplied examples.
# Placeholders: {plan_evidence}, {fewshot}, {task}.
few_shot_solver_prompt = PromptTemplate(
    template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.

##My Plans and Evidences##
{plan_evidence}

##Example Output##
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
So, <your conclusion>.

##Example##
{fewshot}

##Your Task##
{task}

##Now Begin##
"""
)
|
66
knowledgehub/pipelines/agents/rewoo/solver.py
Normal file
66
knowledgehub/pipelines/agents/rewoo/solver.py
Normal file
|
@ -0,0 +1,66 @@
|
||||||
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
from kotaemon.prompt.template import PromptTemplate
|
||||||
|
|
||||||
|
from ..base import BaseLLM
|
||||||
|
from ..output.base import BaseScratchPad
|
||||||
|
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
|
||||||
|
|
||||||
|
|
||||||
|
class Solver(BaseComponent):
    """Build the solving prompt and ask the LLM for the final answer.

    The prompt is composed from a template, the plans and evidence gathered
    so far, optional few-shot examples and the instruction.
    """

    model: BaseLLM
    prompt_template: Optional[PromptTemplate] = None
    examples: Optional[Union[str, List[str]]] = None

    def _compose_fewshot_prompt(self) -> str:
        """Return the examples as one string ("" when there are none).

        A list of examples is joined with blank lines after stripping
        leading/trailing newlines from each entry.
        """
        if self.examples is None:
            return ""
        if isinstance(self.examples, str):
            return self.examples
        return "\n\n".join([e.strip("\n") for e in self.examples])

    def _compose_prompt(self, instruction, plan_evidence) -> str:
        """
        Compose the prompt from template, plan&evidence, examples and instruction.

        A user-supplied template wins; it receives ``fewshot`` only if it
        declares that placeholder. Otherwise the built-in few-shot template
        is used when examples are given, and the zero-shot template when not.
        """
        fewshot = self._compose_fewshot_prompt()

        if self.prompt_template is not None:
            if "fewshot" in self.prompt_template.placeholders:
                return self.prompt_template.populate(
                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
                )
            return self.prompt_template.populate(
                plan_evidence=plan_evidence, task=instruction
            )

        if self.examples is not None:
            return few_shot_solver_prompt.populate(
                plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
            )
        return zero_shot_solver_prompt.populate(
            plan_evidence=plan_evidence, task=instruction
        )

    def run(
        self,
        instruction: str,
        plan_evidence: str,
        output: Optional[BaseScratchPad] = None,
    ) -> Any:
        """Send the composed prompt to the model and return its response.

        Args:
            instruction: The task to solve.
            plan_evidence: Plans and evidence text inserted into the prompt.
            output: Scratch pad for reporting; a new BaseScratchPad is
                created when omitted.

        Returns:
            The model response, or None when the model call raises
            ValueError (the failure is reported through ``output.error``),
            so callers must handle a None result.
        """
        if output is None:
            output = BaseScratchPad()
        response = None
        output.info("Running Solver")
        output.debug(f"Instruction: {instruction}")
        output.debug(f"Plan Evidence: {plan_evidence}")
        prompt = self._compose_prompt(instruction, plan_evidence)
        output.debug(f"Prompt: {prompt}")
        try:
            response = self.model(prompt)
            output.info("Solver run successful.")
        except ValueError:
            output.error("Solver failed to retrieve response from LLM")

        return response
|
22
knowledgehub/pipelines/agents/utils.py
Normal file
22
knowledgehub/pipelines/agents/utils.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
from .base import AgentOutput
|
||||||
|
|
||||||
|
|
||||||
|
def get_plugin_response_content(output) -> str:
    """
    Extract the text of a plugin response.

    An ``AgentOutput`` yields its ``output`` field; any other value is
    converted with ``str``.
    """
    return output.output if isinstance(output, AgentOutput) else str(output)
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_cost(model_name: str, prompt_token: int, completion_token: int) -> float:
    """
    Calculate the cost of a prompt and completion.

    Placeholder: the arguments are currently ignored and 0.0 is always
    returned.

    Returns:
        float: Cost of the provided model name with provided token information
    """
    # TODO: to be implemented
    return 0.0
|
68
tests/test_agent.py
Normal file
68
tests/test_agent.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.pipelines.agents.rewoo import RewooAgent
|
||||||
|
from kotaemon.pipelines.tools import GoogleSearchTool, WikipediaTool
|
||||||
|
|
||||||
|
# Content of the second canned completion; the test expects the agent to
# return exactly this text.
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
# Canned chat-completion payloads used as ``side_effect`` of the patched
# ChatCompletion.create: the first carries a two-step plan, the second the
# final answer.
_openai_chat_completion_responses = [
    {
        "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
        "object": "chat.completion",
        "created": 1692338378,
        "model": "gpt-35-turbo",
        "choices": [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": "#Plan1: Search for Cinnamon AI company on Google\n"
                    "#E1: google_search[Cinnamon AI company]\n"
                    "#Plan2: Search for Cinnamon on Wikipedia\n"
                    "#E2: wikipedia[Cinnamon]",
                },
            }
        ],
        "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
    },
    {
        "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
        "object": "chat.completion",
        "created": 1692338378,
        "model": "gpt-35-turbo",
        "choices": [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": FINAL_RESPONSE_TEXT,
                },
            }
        ],
        "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
    },
]
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
    "openai.api_resources.chat_completion.ChatCompletion.create",
    side_effect=_openai_chat_completion_responses,
)
def test_rewoo_agent(openai_completion):
    """Run RewooAgent end to end against canned chat completions.

    The patched completion endpoint returns the plan first and the final
    answer second; the agent must surface the final answer unchanged.
    """
    # NOTE(review): the canned plan names google_search and wikipedia, so the
    # real tools appear to be invoked here -- confirm this is network-safe.
    agent = RewooAgent(
        llm=AzureChatOpenAI(
            openai_api_base="https://dummy.openai.azure.com/",
            openai_api_key="dummy",
            openai_api_version="2023-03-15-preview",
            deployment_name="dummy-q2",
            temperature=0,
        ),
        plugins=[GoogleSearchTool(), WikipediaTool()],
    )

    result = agent("Tell me about Cinnamon AI company")

    openai_completion.assert_called()
    assert result.output == FINAL_RESPONSE_TEXT
|
Loading…
Reference in New Issue
Block a user