diff --git a/knowledgehub/pipelines/agents/__init__.py b/knowledgehub/pipelines/agents/__init__.py new file mode 100644 index 0000000..3a12096 --- /dev/null +++ b/knowledgehub/pipelines/agents/__init__.py @@ -0,0 +1,3 @@ +from .base import BaseAgent + +__all__ = ["BaseAgent"] diff --git a/knowledgehub/pipelines/agents/base.py b/knowledgehub/pipelines/agents/base.py new file mode 100644 index 0000000..cb05546 --- /dev/null +++ b/knowledgehub/pipelines/agents/base.py @@ -0,0 +1,67 @@ +from enum import Enum +from typing import Dict, List, Union + +from pydantic import BaseModel + +from kotaemon.llms.chats.base import ChatLLM +from kotaemon.llms.completions.base import LLM +from kotaemon.pipelines.tools import BaseTool +from kotaemon.prompt.template import PromptTemplate + +BaseLLM = Union[ChatLLM, LLM] + + +class AgentType(Enum): + """ + Enumerated type for agent types. + """ + + openai = "openai" + react = "react" + rewoo = "rewoo" + vanilla = "vanilla" + openai_memory = "openai_memory" + + @staticmethod + def get_agent_class(_type: "AgentType"): + """ + Get agent class from agent type. + :param _type: agent type + :return: agent class + """ + if _type == AgentType.rewoo: + from .rewoo.agent import RewooAgent + + return RewooAgent + else: + raise ValueError(f"Unknown agent type: {_type}") + + +class AgentOutput(BaseModel): + """ + Pydantic model for agent output. + """ + + output: str + cost: float + token_usage: int + + +class BaseAgent(BaseTool): + name: str + """Name of the agent.""" + type: AgentType + """Agent type, must be one of AgentType""" + description: str + """Description used to tell the model how/when/why to use the agent. + You can provide few-shot examples as a part of the description. This will be + input to the prompt of LLM.""" + llm: Union[BaseLLM, Dict[str, BaseLLM]] + """Specify LLM to be used in the model, cam be a dict to supply different + LLMs to multiple purposes in the agent""" + prompt_template: Union[PromptTemplate, Dict[str, PromptTemplate]] + """A prompt template or a dict to supply different prompt to the agent + """ + plugins: List[BaseTool] + """List of plugins / tools to be used in the agent + """ diff --git a/knowledgehub/pipelines/agents/output/__init__.py b/knowledgehub/pipelines/agents/output/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/knowledgehub/pipelines/agents/output/base.py b/knowledgehub/pipelines/agents/output/base.py new file mode 100644 index 0000000..242daa5 --- /dev/null +++ b/knowledgehub/pipelines/agents/output/base.py @@ -0,0 +1,219 @@ +import json +import logging +import os +from dataclasses import dataclass +from typing import Any, Dict, NamedTuple, Union + + +def check_log(): + """ + Checks if logging has been enabled. + :return: True if logging has been enabled, False otherwise. + :rtype: bool + """ + return os.environ.get("LOG_PATH", None) is not None + + +class BaseScratchPad: + """ + Base class for output handlers. + + Attributes: + ----------- + logger : logging.Logger + The logger object to log messages. + + Methods: + -------- + stop(): + Stop the output. + + update_status(output: str, **kwargs): + Update the status of the output. + + thinking(name: str): + Log that a process is thinking. + + done(_all=False): + Log that the process is done. + + stream_print(item: str): + Not implemented. + + json_print(item: Dict[str, Any]): + Log a JSON object. + + panel_print(item: Any, title: str = "Output", stream: bool = False): + Log a panel output. + + clear(): + Not implemented. + + print(content: str, **kwargs): + Log arbitrary content. + + format_json(json_obj: str): + Format a JSON object. + + debug(content: str, **kwargs): + Log a debug message. + + info(content: str, **kwargs): + Log an informational message. + + warning(content: str, **kwargs): + Log a warning message. + + error(content: str, **kwargs): + Log an error message. + + critical(content: str, **kwargs): + Log a critical message. + """ + + def __init__(self): + """ + Initialize the BaseOutput object. + + """ + self.logger = logging + self.log = [] + + def stop(self): + """ + Stop the output. + """ + + def update_status(self, output: str, **kwargs): + """ + Update the status of the output. + """ + if check_log(): + self.logger.info(output) + + def thinking(self, name: str): + """ + Log that a process is thinking. + """ + if check_log(): + self.logger.info(f"{name} is thinking...") + + def done(self, _all=False): + """ + Log that the process is done. + """ + + if check_log(): + self.logger.info("Done") + + def stream_print(self, item: str): + """ + Stream print. + """ + + def json_print(self, item: Dict[str, Any]): + """ + Log a JSON object. + """ + if check_log(): + self.logger.info(json.dumps(item, indent=2)) + + def panel_print(self, item: Any, title: str = "Output", stream: bool = False): + """ + Log a panel output. + + Args: + item : Any + The item to log. + title : str, optional + The title of the panel, defaults to "Output". + stream : bool, optional + """ + if not stream: + self.log.append(item) + if check_log(): + self.logger.info("-" * 20) + self.logger.info(item) + self.logger.info("-" * 20) + + def clear(self): + """ + Not implemented. + """ + + def print(self, content: str, **kwargs): + """ + Log arbitrary content. + """ + self.log.append(content) + if check_log(): + self.logger.info(content) + + def format_json(self, json_obj: str): + """ + Format a JSON object. + """ + formatted_json = json.dumps(json_obj, indent=2) + return formatted_json + + def debug(self, content: str, **kwargs): + """ + Log a debug message. + """ + if check_log(): + self.logger.debug(content, **kwargs) + + def info(self, content: str, **kwargs): + """ + Log an informational message. + """ + if check_log(): + self.logger.info(content, **kwargs) + + def warning(self, content: str, **kwargs): + """ + Log a warning message. + """ + if check_log(): + self.logger.warning(content, **kwargs) + + def error(self, content: str, **kwargs): + """ + Log an error message. + """ + if check_log(): + self.logger.error(content, **kwargs) + + def critical(self, content: str, **kwargs): + """ + Log a critical message. + """ + if check_log(): + self.logger.critical(content, **kwargs) + + +@dataclass +class AgentAction: + """Agent's action to take. + + Args: + tool: The tool to invoke. + tool_input: The input to the tool. + log: The log message. + """ + + tool: str + tool_input: Union[str, dict] + log: str + + +class AgentFinish(NamedTuple): + """Agent's return value when finishing execution. + + Args: + return_values: The return values of the agent. + log: The log message. + """ + + return_values: dict + log: str diff --git a/knowledgehub/pipelines/agents/rewoo/__init__.py b/knowledgehub/pipelines/agents/rewoo/__init__.py new file mode 100644 index 0000000..bf1ecb9 --- /dev/null +++ b/knowledgehub/pipelines/agents/rewoo/__init__.py @@ -0,0 +1,3 @@ +from .agent import RewooAgent + +__all__ = ["RewooAgent"] diff --git a/knowledgehub/pipelines/agents/rewoo/agent.py b/knowledgehub/pipelines/agents/rewoo/agent.py new file mode 100644 index 0000000..ed04e56 --- /dev/null +++ b/knowledgehub/pipelines/agents/rewoo/agent.py @@ -0,0 +1,270 @@ +import logging +import re +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from pydantic import BaseModel, create_model + +from kotaemon.llms.chats.base import ChatLLM +from kotaemon.llms.completions.base import LLM +from kotaemon.prompt.template import PromptTemplate + +from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool +from ..output.base import BaseScratchPad +from ..utils import get_plugin_response_content +from .planner import Planner +from .solver import Solver + + +class RewooAgent(BaseAgent): + """Distributive RewooAgent class inherited from BaseAgent. + Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf""" + + name: str = "RewooAgent" + type: AgentType = AgentType.rewoo + description: str = "RewooAgent for answering multi-step reasoning questions" + llm: Union[BaseLLM, Dict[str, BaseLLM]] # {"Planner": xxx, "Solver": xxx} + prompt_template: Dict[ + str, PromptTemplate + ] = dict() # {"Planner": xxx, "Solver": xxx} + plugins: List[BaseTool] = list() + examples: Dict[str, Union[str, List[str]]] = dict() + args_schema: Optional[Type[BaseModel]] = create_model( + "ReactArgsSchema", instruction=(str, ...) + ) + + def _get_llms(self): + if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM): + return {"Planner": self.llm, "Solver": self.llm} + elif ( + isinstance(self.llm, dict) + and "Planner" in self.llm + and "Solver" in self.llm + ): + return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]} + else: + raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.") + + def _parse_plan_map( + self, planner_response: str + ) -> Tuple[Dict[str, List[str]], Dict[str, str]]: + """ + Parse planner output. It should be an n-to-n mapping from Plans to #Es. + This is because sometimes LLM cannot follow the strict output format. + Example: + #Plan1 + #E1 + #E2 + should result in: {"#Plan1": ["#E1", "#E2"]} + Or: + #Plan1 + #Plan2 + #E1 + should result in: {"#Plan1": [], "#Plan2": ["#E1"]} + This function should also return a plan map. + + Returns: + Tuple[Dict[str, List[str]], Dict[str, str]]: A list of plan map + """ + valid_chunk = [ + line + for line in planner_response.splitlines() + if line.startswith("#Plan") or line.startswith("#E") + ] + + plan_to_es: Dict[str, List[str]] = dict() + plans: Dict[str, str] = dict() + for line in valid_chunk: + if line.startswith("#Plan"): + plan = line.split(":", 1)[0].strip() + plans[plan] = line.split(":", 1)[1].strip() + plan_to_es[plan] = [] + elif line.startswith("#E"): + plan_to_es[plan].append(line.split(":", 1)[0].strip()) + + return plan_to_es, plans + + def _parse_planner_evidences( + self, planner_response: str + ) -> Tuple[Dict[str, str], List[List[str]]]: + """ + Parse planner output. This should return a mapping from #E to tool call. + It should also identify the level of each #E in dependency map. + Example: + { + "#E1": "Tool1", "#E2": "Tool2", + "#E3": "Tool3", "#E4": "Tool4" + }, [[#E1, #E2], [#E3, #E4]] + + Returns: + Tuple[dict[str, str], List[List[str]]]: + A mapping from #E to tool call and a list of levels. + """ + evidences: Dict[str, str] = dict() + dependence: Dict[str, List[str]] = dict() + for line in planner_response.splitlines(): + if line.startswith("#E") and line[2].isdigit(): + e, tool_call = line.split(":", 1) + e, tool_call = e.strip(), tool_call.strip() + if len(e) == 3: + dependence[e] = [] + evidences[e] = tool_call + for var in re.findall(r"#E\d+", tool_call): + if var in evidences: + dependence[e].append(var) + else: + evidences[e] = "No evidence found" + level = [] + while dependence: + select = [i for i in dependence if not dependence[i]] + if len(select) == 0: + raise ValueError("Circular dependency detected.") + level.append(select) + for item in select: + dependence.pop(item) + for item in dependence: + for i in select: + if i in dependence[item]: + dependence[item].remove(i) + + return evidences, level + + def _run_plugin( + self, + e: str, + planner_evidences: Dict[str, str], + worker_evidences: Dict[str, str], + output=BaseScratchPad(), + ): + """ + Run a plugin for a given evidence. + This function should also cumulate the cost and tokens. + """ + result = dict(e=e, plugin_cost=0, plugin_token=0, evidence="") + tool_call = planner_evidences[e] + if "[" not in tool_call: + result["evidence"] = tool_call + else: + tool, tool_input = tool_call.split("[", 1) + tool_input = tool_input[:-1] + # find variables in input and replace with previous evidences + for var in re.findall(r"#E\d+", tool_input): + if var in worker_evidences: + tool_input = tool_input.replace(var, worker_evidences.get(var, "")) + try: + selected_plugin = self._find_plugin(tool) + if selected_plugin is None: + raise ValueError("Invalid plugin detected") + tool_response = selected_plugin(tool_input) + # cumulate agent-as-plugin costs and tokens. + if isinstance(tool_response, AgentOutput): + result["plugin_cost"] = tool_response.cost + result["plugin_token"] = tool_response.token_usage + result["evidence"] = get_plugin_response_content(tool_response) + except ValueError: + result["evidence"] = "No evidence found." + finally: + output.panel_print( + result["evidence"], f"[green] Function Response of [blue]{tool}: " + ) + return result + + def _get_worker_evidence( + self, + planner_evidences: Dict[str, str], + evidences_level: List[List[str]], + output=BaseScratchPad(), + ) -> Any: + """ + Parallel execution of plugins in DAG for speedup. + This is one of core benefits of ReWOO agents. + + Args: + planner_evidences: A mapping from #E to tool call. + evidences_level: A list of levels of evidences. + Calculated from DAG of plugin calls. + output: Output object, defaults to BaseOutput(). + Returns: + A mapping from #E to tool call. + """ + worker_evidences: Dict[str, str] = dict() + plugin_cost, plugin_token = 0.0, 0.0 + with ThreadPoolExecutor() as pool: + for level in evidences_level: + results = [] + for e in level: + results.append( + pool.submit( + self._run_plugin, + e, + planner_evidences, + worker_evidences, + output, + ) + ) + if len(results) > 1: + output.update_status(f"Running tasks {level} in parallel.") + else: + output.update_status(f"Running task {level[0]}.") + for r in results: + resp = r.result() + plugin_cost += resp["plugin_cost"] + plugin_token += resp["plugin_token"] + worker_evidences[resp["e"]] = resp["evidence"] + output.done() + + return worker_evidences, plugin_cost, plugin_token + + def _find_plugin(self, name: str): + for p in self.plugins: + if p.name == name: + return p + + def _run_tool(self, instruction: str) -> AgentOutput: + """ + Run the agent with a given instruction. + """ + logging.info(f"Running {self.name} with instruction: {instruction}") + total_cost = 0.0 + total_token = 0 + + planner_llm = self._get_llms()["Planner"] + solver_llm = self._get_llms()["Solver"] + + planner = Planner( + model=planner_llm, + plugins=self.plugins, + prompt_template=self.prompt_template.get("Planner", None), + examples=self.examples.get("Planner", None), + ) + solver = Solver( + model=solver_llm, + prompt_template=self.prompt_template.get("Solver", None), + examples=self.examples.get("Solver", None), + ) + + # Plan + planner_output = planner(instruction) + plannner_text_output = planner_output.text[0] + plan_to_es, plans = self._parse_plan_map(plannner_text_output) + planner_evidences, evidence_level = self._parse_planner_evidences( + plannner_text_output + ) + + # Work + worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence( + planner_evidences, evidence_level + ) + worker_log = "" + for plan in plan_to_es: + worker_log += f"{plan}: {plans[plan]}\n" + for e in plan_to_es[plan]: + worker_log += f"{e}: {worker_evidences[e]}\n" + + # Solve + solver_output = solver(instruction, worker_log) + solver_output_text = solver_output.text[0] + + return AgentOutput( + output=solver_output_text, cost=total_cost, token_usage=total_token + ) diff --git a/knowledgehub/pipelines/agents/rewoo/planner.py b/knowledgehub/pipelines/agents/rewoo/planner.py new file mode 100644 index 0000000..af2ddf3 --- /dev/null +++ b/knowledgehub/pipelines/agents/rewoo/planner.py @@ -0,0 +1,82 @@ +from typing import Any, List, Optional, Union + +from kotaemon.base import BaseComponent +from kotaemon.prompt.template import PromptTemplate + +from ..base import BaseLLM, BaseTool +from ..output.base import BaseScratchPad +from .prompt import zero_shot_planner_prompt + + +class Planner(BaseComponent): + model: BaseLLM + prompt_template: Optional[PromptTemplate] = None + examples: Optional[Union[str, List[str]]] = None + plugins: List[BaseTool] + + def _compose_worker_description(self) -> str: + """ + Compose the worker prompt from the workers. + + Example: + toolname1[input]: tool1 description + toolname2[input]: tool2 description + """ + prompt = "" + try: + for worker in self.plugins: + prompt += f"{worker.name}[input]: {worker.description}\n" + except Exception: + raise ValueError("Worker must have a name and description.") + return prompt + + def _compose_fewshot_prompt(self) -> str: + if self.examples is None: + return "" + if isinstance(self.examples, str): + return self.examples + else: + return "\n\n".join([e.strip("\n") for e in self.examples]) + + def _compose_prompt(self, instruction) -> str: + """ + Compose the prompt from template, worker description, examples and instruction. + """ + worker_desctription = self._compose_worker_description() + fewshot = self._compose_fewshot_prompt() + if self.prompt_template is not None: + if "fewshot" in self.prompt_template.placeholders: + return self.prompt_template.populate( + tool_description=worker_desctription, + fewshot=fewshot, + task=instruction, + ) + else: + return self.prompt_template.populate( + tool_description=worker_desctription, task=instruction + ) + else: + if self.examples is not None: + return zero_shot_planner_prompt.populate( + tool_description=worker_desctription, + fewshot=fewshot, + task=instruction, + ) + else: + return zero_shot_planner_prompt.populate( + tool_description=worker_desctription, task=instruction + ) + + def run(self, instruction: str, output: BaseScratchPad = BaseScratchPad()) -> Any: + response = None + output.info("Running Planner") + prompt = self._compose_prompt(instruction) + output.debug(f"Prompt: {prompt}") + try: + response = self.model(prompt) + output.info("Planner run successful.") + except ValueError: + output.error("Planner failed to retrieve response from LLM") + raise ValueError("Planner failed to retrieve response from LLM") + + return response diff --git a/knowledgehub/pipelines/agents/rewoo/prompt.py b/knowledgehub/pipelines/agents/rewoo/prompt.py new file mode 100644 index 0000000..569e89f --- /dev/null +++ b/knowledgehub/pipelines/agents/rewoo/prompt.py @@ -0,0 +1,119 @@ +# flake8: noqa + +from kotaemon.prompt.template import PromptTemplate + +zero_shot_planner_prompt = PromptTemplate( + template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools. +For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step. +You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs. + +##Available Tools## +{tool_description} + +##Output Format (Replace '<...>')## +#Plan1: +#E1: [] (eg. Search[What is Python]) +#Plan2: +#E2: [] +And so on... + +##Your Task## +{task} + +##Now Begin## +""" +) + +one_shot_planner_prompt = PromptTemplate( + template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools. +For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step. +You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs. + +##Available Tools## +{tool_description} + +##Output Format## +#Plan1: +#E1: [] +#Plan2: +#E2: [] +And so on... + +##Example## +Task: What is the 4th root of 64 to the power of 3? +#Plan1: Find the 4th root of 64 +#E1: Calculator[64^(1/4)] +#Plan2: Raise the result from #Plan1 to the power of 3 +#E2: Calculator[#E1^3] + +##Your Task## +{task} + +##Now Begin## +""" +) + + +few_shot_planner_prompt = PromptTemplate( + template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools. +For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step. +You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs. + +##Available Tools## +{tool_description} + +##Output Format (Replace '<...>')## +#Plan1: +#E1: [] +#Plan2: +#E2: [] +And so on... + +##Examples## +{fewshot} + +##Your Task## +{task} + +##Now Begin## +""" +) + +zero_shot_solver_prompt = PromptTemplate( + template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful. +Your task is to briefly summarize each step, then make a short final conclusion for your task. + +##My Plans and Evidences## +{plan_evidence} + +##Example Output## +First, I , and I think <...>; Second, I <...>, and I think <...>; .... +So, . + +##Your Task## +{task} + +##Now Begin## +""" +) + +few_shot_solver_prompt = PromptTemplate( + template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful. +Your task is to briefly summarize each step, then make a short final conclusion for your task. + +##My Plans and Evidences## +{plan_evidence} + +##Example Output## +First, I , and I think <...>; Second, I <...>, and I think <...>; .... +So, . + +##Example## +{fewshot} + +##Your Task## +{task} + +##Now Begin## +""" +) diff --git a/knowledgehub/pipelines/agents/rewoo/solver.py b/knowledgehub/pipelines/agents/rewoo/solver.py new file mode 100644 index 0000000..2e6b5c9 --- /dev/null +++ b/knowledgehub/pipelines/agents/rewoo/solver.py @@ -0,0 +1,66 @@ +from typing import Any, List, Optional, Union + +from kotaemon.base import BaseComponent +from kotaemon.prompt.template import PromptTemplate + +from ..base import BaseLLM +from ..output.base import BaseScratchPad +from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt + + +class Solver(BaseComponent): + model: BaseLLM + prompt_template: Optional[PromptTemplate] = None + examples: Optional[Union[str, List[str]]] = None + + def _compose_fewshot_prompt(self) -> str: + if self.examples is None: + return "" + if isinstance(self.examples, str): + return self.examples + else: + return "\n\n".join([e.strip("\n") for e in self.examples]) + + def _compose_prompt(self, instruction, plan_evidence) -> str: + """ + Compose the prompt from template, plan&evidence, examples and instruction. + """ + fewshot = self._compose_fewshot_prompt() + if self.prompt_template is not None: + if "fewshot" in self.prompt_template.placeholders: + return self.prompt_template.populate( + plan_evidence=plan_evidence, fewshot=fewshot, task=instruction + ) + else: + return self.prompt_template.populate( + plan_evidence=plan_evidence, task=instruction + ) + else: + if self.examples is not None: + return few_shot_solver_prompt.populate( + plan_evidence=plan_evidence, fewshot=fewshot, task=instruction + ) + else: + return zero_shot_solver_prompt.populate( + plan_evidence=plan_evidence, task=instruction + ) + + def run( + self, + instruction: str, + plan_evidence: str, + output: BaseScratchPad = BaseScratchPad(), + ) -> Any: + response = None + output.info("Running Solver") + output.debug(f"Instruction: {instruction}") + output.debug(f"Plan Evidence: {plan_evidence}") + prompt = self._compose_prompt(instruction, plan_evidence) + output.debug(f"Prompt: {prompt}") + try: + response = self.model(prompt) + output.info("Solver run successful.") + except ValueError: + output.error("Solver failed to retrieve response from LLM") + + return response diff --git a/knowledgehub/pipelines/agents/utils.py b/knowledgehub/pipelines/agents/utils.py new file mode 100644 index 0000000..4845526 --- /dev/null +++ b/knowledgehub/pipelines/agents/utils.py @@ -0,0 +1,22 @@ +from .base import AgentOutput + + +def get_plugin_response_content(output) -> str: + """ + Wrapper for AgentOutput content return + """ + if isinstance(output, AgentOutput): + return output.output + else: + return str(output) + + +def calculate_cost(model_name: str, prompt_token: int, completion_token: int) -> float: + """ + Calculate the cost of a prompt and completion. + + Returns: + float: Cost of the provided model name with provided token information + """ + # TODO: to be implemented + return 0.0 diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 0000000..68baf08 --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,68 @@ +from unittest.mock import patch + +from kotaemon.llms.chats.openai import AzureChatOpenAI +from kotaemon.pipelines.agents.rewoo import RewooAgent +from kotaemon.pipelines.tools import GoogleSearchTool, WikipediaTool + +FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!" +_openai_chat_completion_responses = [ + { + "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", + "object": "chat.completion", + "created": 1692338378, + "model": "gpt-35-turbo", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": "#Plan1: Search for Cinnamon AI company on Google\n" + "#E1: google_search[Cinnamon AI company]\n" + "#Plan2: Search for Cinnamon on Wikipedia\n" + "#E2: wikipedia[Cinnamon]", + }, + } + ], + "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, + }, + { + "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", + "object": "chat.completion", + "created": 1692338378, + "model": "gpt-35-turbo", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": FINAL_RESPONSE_TEXT, + }, + } + ], + "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, + }, +] + + +@patch( + "openai.api_resources.chat_completion.ChatCompletion.create", + side_effect=_openai_chat_completion_responses, +) +def test_rewoo_agent(openai_completion): + llm = AzureChatOpenAI( + openai_api_base="https://dummy.openai.azure.com/", + openai_api_key="dummy", + openai_api_version="2023-03-15-preview", + deployment_name="dummy-q2", + temperature=0, + ) + + plugins = [GoogleSearchTool(), WikipediaTool()] + + agent = RewooAgent(llm=llm, plugins=plugins) + + response = agent("Tell me about Cinnamon AI company") + openai_completion.assert_called() + assert response.output == FINAL_RESPONSE_TEXT