From 3cceec63ef9550ce15eaf7e48bebc444395c47b8 Mon Sep 17 00:00:00 2001
From: "Tuan Anh Nguyen Dang (Tadashi_Cin)"
Date: Mon, 2 Oct 2023 11:29:12 +0700
Subject: [PATCH] [AUR-431] Add ReAct Agent (#34)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

* add base Agent Interface, add ReWoo Agent

* minor update

* update test

* fix typo

* remove unneeded print

* update rewoo agent

* add LLMTool

* update BaseAgent type

* add ReAct agent

* add ReAct agent

* minor update

* minor update

* minor update

* minor update

* update docstring

* fix max_iteration

---------

Co-authored-by: trducng
---
 knowledgehub/llms/chats/base.py               |  22 +-
 knowledgehub/pipelines/agents/base.py         |   6 +-
 .../pipelines/agents/react/__init__.py        |   3 +
 knowledgehub/pipelines/agents/react/agent.py  | 188 ++++++++++++++++++
 knowledgehub/pipelines/agents/react/prompt.py |  28 +++
 knowledgehub/pipelines/agents/rewoo/agent.py  |   2 +-
 knowledgehub/pipelines/tools/__init__.py      |   3 +-
 knowledgehub/pipelines/tools/llm.py           |  36 ++++
 tests/test_agent.py                           |  79 ++++++--
 9 files changed, 340 insertions(+), 27 deletions(-)
 create mode 100644 knowledgehub/pipelines/agents/react/__init__.py
 create mode 100644 knowledgehub/pipelines/agents/react/agent.py
 create mode 100644 knowledgehub/pipelines/agents/react/prompt.py
 create mode 100644 knowledgehub/pipelines/tools/llm.py
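Note on the design (commentary only; git am ignores everything between this point and
the first "diff --git" line below): the heart of the new agent is _parse_output() in
react/agent.py, which decodes raw LLM text into either an AgentAction (a tool call) or
an AgentFinish. A minimal standalone sketch of that decode step: the regex and the
FINAL_ANSWER_ACTION constant are copied from the patch, while the parse() helper and
its tuple return values are illustrative only, not kotaemon API.

import re

FINAL_ANSWER_ACTION = "Final Answer:"
ACTION_REGEX = r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"

def parse(text):
    # Hypothetical helper mirroring ReactAgent._parse_output: returns
    # ("action", tool_name, tool_input) or ("finish", final_answer).
    match = re.search(ACTION_REGEX, text, re.DOTALL)
    if match:
        tool_input = match.group(2).strip(" ")
        if not tool_input.startswith("SELECT "):  # keep quotes around SQL intact
            tool_input = tool_input.strip('"')
        return ("action", match.group(1).strip(), tool_input)
    if FINAL_ANSWER_ACTION in text:
        return ("finish", text.split(FINAL_ANSWER_ACTION)[-1].strip())
    return ("finish", text)  # lenient fallback, as when strict_decode=False

print(parse("Thought: look it up\nAction: wikipedia\nAction Input: Cinnamon AI"))
# ('action', 'wikipedia', 'Cinnamon AI')
print(parse("Thought: done\nFinal Answer: Hello Cinnamon AI!"))
# ('finish', 'Hello Cinnamon AI!')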
diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py
index fd44aad..27ac5f7 100644
--- a/knowledgehub/llms/chats/base.py
+++ b/knowledgehub/llms/chats/base.py
@@ -34,16 +34,16 @@ class LangchainChatLLM(ChatLLM):
     def agent(self) -> BaseLanguageModel:
         return self._lc_class(**self._kwargs)
 
-    def run_raw(self, text: str) -> LLMInterface:
+    def run_raw(self, text: str, **kwargs) -> LLMInterface:
         message = HumanMessage(content=text)
-        return self.run_document([message])
+        return self.run_document([message], **kwargs)
 
-    def run_batch_raw(self, text: List[str]) -> List[LLMInterface]:
+    def run_batch_raw(self, text: List[str], **kwargs) -> List[LLMInterface]:
         inputs = [[HumanMessage(content=each)] for each in text]
-        return self.run_batch_document(inputs)
+        return self.run_batch_document(inputs, **kwargs)
 
-    def run_document(self, text: List[Message]) -> LLMInterface:
-        pred = self.agent.generate([text])  # type: ignore
+    def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
+        pred = self.agent.generate([text], **kwargs)  # type: ignore
         return LLMInterface(
             text=[each.text for each in pred.generations[0]],
             completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
@@ -52,20 +52,22 @@ class LangchainChatLLM(ChatLLM):
             logits=[],
         )
 
-    def run_batch_document(self, text: List[List[Message]]) -> List[LLMInterface]:
+    def run_batch_document(
+        self, text: List[List[Message]], **kwargs
+    ) -> List[LLMInterface]:
         outputs = []
         for each_text in text:
-            outputs.append(self.run_document(each_text))
+            outputs.append(self.run_document(each_text, **kwargs))
         return outputs
 
-    def is_document(self, text) -> bool:
+    def is_document(self, text, **kwargs) -> bool:
         if isinstance(text, str):
             return False
         elif isinstance(text, List) and isinstance(text[0], str):
             return False
         return True
 
-    def is_batch(self, text) -> bool:
+    def is_batch(self, text, **kwargs) -> bool:
         if isinstance(text, str):
             return False
         elif isinstance(text, List):
diff --git a/knowledgehub/pipelines/agents/base.py b/knowledgehub/pipelines/agents/base.py
index cb05546..a3bc603 100644
--- a/knowledgehub/pipelines/agents/base.py
+++ b/knowledgehub/pipelines/agents/base.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel
 
@@ -50,7 +50,7 @@ class AgentOutput(BaseModel):
 class BaseAgent(BaseTool):
     name: str
     """Name of the agent."""
-    type: AgentType
+    agent_type: AgentType
     """Agent type, must be one of AgentType"""
     description: str
     """Description used to tell the model how/when/why to use the agent.
@@ -59,7 +59,7 @@ class BaseAgent(BaseTool):
     llm: Union[BaseLLM, Dict[str, BaseLLM]]
     """Specify LLM to be used in the model, can be a dict to supply different
     LLMs for multiple purposes in the agent"""
-    prompt_template: Union[PromptTemplate, Dict[str, PromptTemplate]]
+    prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]]
     """A prompt template or a dict to supply different prompts to the agent
     """
     plugins: List[BaseTool]
diff --git a/knowledgehub/pipelines/agents/react/__init__.py b/knowledgehub/pipelines/agents/react/__init__.py
new file mode 100644
index 0000000..9fdf370
--- /dev/null
+++ b/knowledgehub/pipelines/agents/react/__init__.py
@@ -0,0 +1,3 @@
+from .agent import ReactAgent
+
+__all__ = ["ReactAgent"]
diff --git a/knowledgehub/pipelines/agents/react/agent.py b/knowledgehub/pipelines/agents/react/agent.py
new file mode 100644
index 0000000..6fc8ddf
--- /dev/null
+++ b/knowledgehub/pipelines/agents/react/agent.py
@@ -0,0 +1,188 @@
+import logging
+import re
+from typing import Dict, List, Optional, Tuple, Type, Union
+
+from pydantic import BaseModel, create_model
+
+from kotaemon.prompt.template import PromptTemplate
+
+from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
+from ..output.base import AgentAction, AgentFinish
+
+FINAL_ANSWER_ACTION = "Final Answer:"
+
+
+class ReactAgent(BaseAgent):
+    """
+    Sequential ReactAgent class inherited from BaseAgent.
+    Implementing ReAct agent paradigm https://arxiv.org/pdf/2210.03629.pdf
+    """
+
+    name: str = "ReactAgent"
+    agent_type: AgentType = AgentType.react
+    description: str = "ReactAgent for answering multi-step reasoning questions"
+    llm: Union[BaseLLM, Dict[str, BaseLLM]]
+    prompt_template: Optional[PromptTemplate] = None
+    plugins: List[BaseTool] = list()
+    examples: Dict[str, Union[str, List[str]]] = dict()
+    args_schema: Optional[Type[BaseModel]] = create_model(
+        "ReactArgsSchema", instruction=(str, ...)
+    )
+    intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
+    """List of AgentAction and observation (tool) output"""
+    max_iterations: int = 10
+    strict_decode: bool = False
+
+    def _compose_plugin_description(self) -> str:
+        """
+        Compose the worker prompt from the workers.
+
+        Example:
+        toolname1[input]: tool1 description
+        toolname2[input]: tool2 description
+        """
+        prompt = ""
+        try:
+            for plugin in self.plugins:
+                prompt += f"{plugin.name}[input]: {plugin.description}\n"
+        except Exception:
+            raise ValueError("Worker must have a name and description.")
+        return prompt
+
+    def _construct_scratchpad(
+        self, intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
+    ) -> str:
+        """Construct the scratchpad that lets the agent continue its thought process."""
+        thoughts = ""
+        for action, observation in intermediate_steps:
+            thoughts += action.log
+            thoughts += f"\nObservation: {observation}\nThought:"
+        return thoughts
+
+    def _parse_output(self, text: str) -> Optional[Union[AgentAction, AgentFinish]]:
+        """
+        Parse the text output from the LLM for the next Action or the Final Answer.
+        Uses a regex on "Action:\n Action Input:\n" to parse the next Action, and
+        FINAL_ANSWER_ACTION to parse the Final Answer.
+
+        Args:
+            text[str]: input text to parse
+        """
+        includes_answer = FINAL_ANSWER_ACTION in text
+        regex = (
+            r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
+        )
+        action_match = re.search(regex, text, re.DOTALL)
+        action_output: Optional[Union[AgentAction, AgentFinish]] = None
+        if action_match:
+            if includes_answer:
+                raise Exception(
+                    "Parsing LLM output produced both a final answer "
+                    f"and a parse-able action: {text}"
+                )
+            action = action_match.group(1).strip()
+            action_input = action_match.group(2)
+            tool_input = action_input.strip(" ")
+            # ensure that if it's a well-formed SQL query we don't remove any trailing " chars
+            if not tool_input.startswith("SELECT "):
+                tool_input = tool_input.strip('"')
+
+            action_output = AgentAction(action, tool_input, text)
+
+        elif includes_answer:
+            action_output = AgentFinish(
+                {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
+            )
+        else:
+            if self.strict_decode:
+                raise Exception(f"Could not parse LLM output: `{text}`")
+            else:
+                action_output = AgentFinish({"output": text}, text)
+
+        return action_output
+
+    def _compose_prompt(self, instruction) -> str:
+        """
+        Compose the prompt from the template, worker descriptions, examples and the instruction.
+        """
+        agent_scratchpad = self._construct_scratchpad(self.intermediate_steps)
+        tool_description = self._compose_plugin_description()
+        tool_names = ", ".join([plugin.name for plugin in self.plugins])
+        if self.prompt_template is None:
+            from .prompt import zero_shot_react_prompt
+
+            self.prompt_template = zero_shot_react_prompt
+        return self.prompt_template.populate(
+            instruction=instruction,
+            agent_scratchpad=agent_scratchpad,
+            tool_description=tool_description,
+            tool_names=tool_names,
+        )
+
+    def _format_function_map(self) -> Dict[str, BaseTool]:
+        """Format the function map for the OpenAI function API.
+
+        Return:
+            Dict[str, Callable]: The function map.
+        """
+        # Map the function name to the real function object.
+        function_map = {}
+        for plugin in self.plugins:
+            function_map[plugin.name] = plugin
+        return function_map
+
+    def clear(self):
+        """
+        Clear and reset the agent.
+        """
+        self.intermediate_steps = []
+
+    def run(self, instruction, max_iterations=None):
+        """
+        Run the agent with the given instruction.
+
+        Args:
+            instruction: Instruction to run the agent with.
+            max_iterations: Maximum number of iterations
+                of reasoning steps, defaults to 10.
+
+        Return:
+            AgentOutput object.
+ """ + if not max_iterations: + max_iterations = self.max_iterations + assert max_iterations > 0 + + self.clear() + logging.info(f"Running {self.name} with instruction: {instruction}") + total_cost = 0.0 + total_token = 0 + + for _ in range(max_iterations): + prompt = self._compose_prompt(instruction) + logging.info(f"Prompt: {prompt}") + response = self.llm(prompt, stop=["Observation:"]) # type: ignore + response_text = response.text[0] + logging.info(f"Response: {response_text}") + action_step = self._parse_output(response_text) + if action_step is None: + raise ValueError("Invalid action") + is_finished_chain = isinstance(action_step, AgentFinish) + if is_finished_chain: + result = "" + else: + assert isinstance(action_step, AgentAction) + action_name = action_step.tool + tool_input = action_step.tool_input + logging.info(f"Action: {action_name}") + logging.info(f"Tool Input: {tool_input}") + result = self._format_function_map()[action_name](tool_input) + logging.info(f"Result: {result}") + + self.intermediate_steps.append((action_step, result)) + if is_finished_chain: + break + + return AgentOutput( + output=response_text, cost=total_cost, token_usage=total_token + ) diff --git a/knowledgehub/pipelines/agents/react/prompt.py b/knowledgehub/pipelines/agents/react/prompt.py new file mode 100644 index 0000000..ba80b16 --- /dev/null +++ b/knowledgehub/pipelines/agents/react/prompt.py @@ -0,0 +1,28 @@ +# flake8: noqa + +from kotaemon.prompt.template import PromptTemplate + +zero_shot_react_prompt = PromptTemplate( + template="""Answer the following questions as best you can. You have access to the following tools: +{tool_description}. +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do + +Action: the action to take, should be one of [{tool_names}] + +Action Input: the input to the action + +Observation: the result of the action + +... (this Thought/Action/Action Input/Observation can repeat N times) +#Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin! After each Action Input. 
+
+Question: {instruction}
+Thought:{agent_scratchpad}
+    """
+)
diff --git a/knowledgehub/pipelines/agents/rewoo/agent.py b/knowledgehub/pipelines/agents/rewoo/agent.py
index ed04e56..1d47eb9 100644
--- a/knowledgehub/pipelines/agents/rewoo/agent.py
+++ b/knowledgehub/pipelines/agents/rewoo/agent.py
@@ -21,7 +21,7 @@ class RewooAgent(BaseAgent):
     Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf"""
 
     name: str = "RewooAgent"
-    type: AgentType = AgentType.rewoo
+    agent_type: AgentType = AgentType.rewoo
     description: str = "RewooAgent for answering multi-step reasoning questions"
     llm: Union[BaseLLM, Dict[str, BaseLLM]]  # {"Planner": xxx, "Solver": xxx}
     prompt_template: Dict[
diff --git a/knowledgehub/pipelines/tools/__init__.py b/knowledgehub/pipelines/tools/__init__.py
index 7302ed6..3869dc9 100644
--- a/knowledgehub/pipelines/tools/__init__.py
+++ b/knowledgehub/pipelines/tools/__init__.py
@@ -1,5 +1,6 @@
 from .base import BaseTool, ComponentTool
 from .google import GoogleSearchTool
+from .llm import LLMTool
 from .wikipedia import WikipediaTool
 
-__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool"]
+__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool", "LLMTool"]
diff --git a/knowledgehub/pipelines/tools/llm.py b/knowledgehub/pipelines/tools/llm.py
new file mode 100644
index 0000000..63835ea
--- /dev/null
+++ b/knowledgehub/pipelines/tools/llm.py
@@ -0,0 +1,36 @@
+from typing import AnyStr, Optional, Type, Union
+
+from pydantic import BaseModel, Field
+
+from kotaemon.llms.chats.base import ChatLLM
+from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.llms.completions.base import LLM
+
+from .base import BaseTool, ToolException
+
+BaseLLM = Union[ChatLLM, LLM]
+
+
+class LLMArgs(BaseModel):
+    query: str = Field(..., description="a search question or prompt")
+
+
+class LLMTool(BaseTool):
+    name = "llm"
+    description = (
+        "A pretrained LLM like yourself. Useful when you need to act with "
+        "general world knowledge and common sense. Prioritize it when you "
+        "are confident in solving the problem "
+        "yourself. Input can be any instruction."
+    )
+    llm: BaseLLM = AzureChatOpenAI()
+    args_schema: Optional[Type[BaseModel]] = LLMArgs
+
+    def _run_tool(self, query: AnyStr) -> str:
+        output = None
+        try:
+            response = self.llm(query)
+        except ValueError:
+            raise ToolException("LLM Tool call failed")
+        output = response.text[0]
+        return output
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 68baf08..e646b6e 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -1,11 +1,12 @@
 from unittest.mock import patch
 
 from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.pipelines.agents.react import ReactAgent
 from kotaemon.pipelines.agents.rewoo import RewooAgent
-from kotaemon.pipelines.tools import GoogleSearchTool, WikipediaTool
+from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool
 
 FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
 
-_openai_chat_completion_responses = [
+_openai_chat_completion_responses_rewoo = [
     {
         "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
         "object": "chat.completion",
         "created": 1692338378,
         "model": "gpt-35-turbo",
         "choices": [
             {
@@ -17,15 +18,24 @@
                 "index": 0,
                 "finish_reason": "stop",
                 "message": {
                     "role": "assistant",
-                    "content": "#Plan1: Search for Cinnamon AI company on Google\n"
-                    "#E1: google_search[Cinnamon AI company]\n"
-                    "#Plan2: Search for Cinnamon on Wikipedia\n"
-                    "#E2: wikipedia[Cinnamon]",
+                    "content": text,
                 },
             }
         ],
         "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
-    },
+    }
+    for text in [
+        (
+            "#Plan1: Search for Cinnamon AI company on Google\n"
+            "#E1: google_search[Cinnamon AI company]\n"
+            "#Plan2: Search for Cinnamon on Wikipedia\n"
+            "#E2: wikipedia[Cinnamon]\n"
+        ),
+        FINAL_RESPONSE_TEXT,
+    ]
+]
+
+_openai_chat_completion_responses_react = [
     {
         "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
         "object": "chat.completion",
         "created": 1692338378,
         "model": "gpt-35-turbo",
         "choices": [
             {
@@ -37,18 +47,36 @@
                 "index": 0,
                 "finish_reason": "stop",
                 "message": {
                     "role": "assistant",
-                    "content": FINAL_RESPONSE_TEXT,
+                    "content": text,
                 },
             }
         ],
         "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
-    },
+    }
+    for text in [
+        (
+            "I don't have prior knowledge about Cinnamon AI company, "
+            "so I should gather information about it.\n"
+            "Action: wikipedia\n"
+            "Action Input: Cinnamon AI company\n"
+        ),
+        (
+            "The information retrieved from Wikipedia is not "
+            "about Cinnamon AI company, but about Blue Prism, "
+            "a British multinational software corporation. "
+            "I need to try another source to gather information "
+            "about Cinnamon AI company.\n"
+            "Action: google_search\n"
+            "Action Input: Cinnamon AI company\n"
+        ),
+        FINAL_RESPONSE_TEXT,
+    ]
 ]
 
 
 @patch(
     "openai.api_resources.chat_completion.ChatCompletion.create",
-    side_effect=_openai_chat_completion_responses,
+    side_effect=_openai_chat_completion_responses_rewoo,
 )
 def test_rewoo_agent(openai_completion):
@@ -58,11 +86,38 @@ def test_rewoo_agent(openai_completion):
     llm = AzureChatOpenAI(
         openai_api_base="https://dummy.openai.azure.com/",
         openai_api_key="dummy",
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2",
         temperature=0,
     )
-
-    plugins = [GoogleSearchTool(), WikipediaTool()]
+    plugins = [
+        GoogleSearchTool(),
+        WikipediaTool(),
+        LLMTool(llm=llm),
+    ]
     agent = RewooAgent(llm=llm, plugins=plugins)
 
     response = agent("Tell me about Cinnamon AI company")
     openai_completion.assert_called()
     assert response.output == FINAL_RESPONSE_TEXT
+
+
+@patch(
+    "openai.api_resources.chat_completion.ChatCompletion.create",
+    side_effect=_openai_chat_completion_responses_react,
+)
+def test_react_agent(openai_completion):
+    llm = AzureChatOpenAI(
+        openai_api_base="https://dummy.openai.azure.com/",
+        openai_api_key="dummy",
+        openai_api_version="2023-03-15-preview",
+        deployment_name="dummy-q2",
+        temperature=0,
+    )
+    plugins = [
+        GoogleSearchTool(),
+        WikipediaTool(),
+        LLMTool(llm=llm),
+    ]
+    agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)
+
+    response = agent("Tell me about Cinnamon AI company")
+    openai_completion.assert_called()
+    assert response.output == FINAL_RESPONSE_TEXT
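
Postscript (outside the patch, after the last hunk): ReactAgent.run() above is a plain
loop: compose the prompt, call the LLM with stop=["Observation:"], parse the reply, then
either execute the named tool and append an (action, observation) pair to the scratchpad,
or stop on a final answer. A compact sketch of that control flow under stubbed
components; react_loop, the scripted replies, and the search lambda below are
hypothetical stand-ins, not kotaemon API.

from typing import Callable, Dict, List, Tuple

FINAL_ANSWER_ACTION = "Final Answer:"

def react_loop(
    llm: Callable[[str], str],
    tools: Dict[str, Callable[[str], str]],
    instruction: str,
    max_iterations: int = 10,
) -> str:
    steps: List[Tuple[str, str]] = []  # (thought/action text, observation)
    for _ in range(max_iterations):
        # Rebuild the scratchpad from prior steps, as _construct_scratchpad does.
        scratchpad = "".join(f"{log}\nObservation: {obs}\nThought:" for log, obs in steps)
        response = llm(f"Question: {instruction}\nThought:{scratchpad}")
        if FINAL_ANSWER_ACTION in response:  # AgentFinish branch
            return response.split(FINAL_ANSWER_ACTION)[-1].strip()
        # AgentAction branch: run the named tool and record the observation.
        action = response.split("Action:")[1].split("Action Input:")[0].strip()
        tool_input = response.split("Action Input:")[1].strip()
        steps.append((response, tools[action](tool_input)))
    return ""  # max_iterations exhausted without a final answer

# Scripted stub: the first call asks for a tool, the second call finishes.
replies = iter([
    "I should look this up.\nAction: search\nAction Input: Cinnamon AI",
    "Final Answer: Hello Cinnamon AI!",
])
print(react_loop(lambda prompt: next(replies), {"search": lambda q: f"results for {q}"}, "Tell me about Cinnamon AI"))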