Refactor agents and tools (#91)
* Move tools to agents
* Move agents to a dedicated place
* Remove subclassing BaseAgent from BaseTool
committed by GitHub
parent 4256030b4f
commit 8e3a1d193f
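
After this refactor, tools live under the agents package and the agents themselves have a dedicated home. A minimal, hedged sketch of the resulting import layout (names taken from the __all__ lists in the new modules; the installed package is assumed to be importable as kotaemon, matching the imports in the diff below):

# hypothetical usage sketch, not part of this commit
from kotaemon.agents import BaseAgent, LangchainAgent, ReactAgent, RewooAgent
from kotaemon.agents.tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool
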
6 knowledgehub/agents/__init__.py Normal file
@@ -0,0 +1,6 @@
from .base import AgentType, BaseAgent
from .langchain import LangchainAgent
from .react.agent import ReactAgent
from .rewoo.agent import RewooAgent

__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"]
68 knowledgehub/agents/base.py Normal file
@@ -0,0 +1,68 @@
from enum import Enum
from typing import Optional, Union

from theflow import Node, Param

from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM

from .tools import BaseTool

BaseLLM = Union[ChatLLM, LLM]


class AgentType(Enum):
    """
    Enumerated type for agent types.
    """

    openai = "openai"
    openai_multi = "openai_multi"
    openai_tool = "openai_tool"
    self_ask = "self_ask"
    react = "react"
    rewoo = "rewoo"
    vanilla = "vanilla"

    @staticmethod
    def get_agent_class(_type: "AgentType"):
        """
        Get the agent class from an agent type.
        :param _type: agent type
        :return: agent class
        """
        if _type == AgentType.rewoo:
            from .rewoo.agent import RewooAgent

            return RewooAgent
        else:
            raise ValueError(f"Unknown agent type: {_type}")


class BaseAgent(BaseComponent):
    """Define the base agent interface"""

    name: str = Param(help="Name of the agent.")
    agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
    description: str = Param(
        help="Description used to tell the model how/when/why to use the agent. "
        "You can provide few-shot examples as a part of the description. This will "
        "be input to the prompt of the LLM."
    )
    llm: Union[BaseLLM, dict[str, BaseLLM]] = Node(
        help="Specify the LLM to be used in the agent; can be a dict to supply "
        "different LLMs for multiple purposes in the agent"
    )
    prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
        help="A prompt template or a dict to supply different prompts to the agent"
    )
    plugins: list[BaseTool] = Param(
        default_callback=lambda _: [],
        help="List of plugins / tools to be used in the agent",
    )

    def add_tools(self, tools: list[BaseTool]) -> None:
        """Helper method to add tools and update agent state if needed"""
        self.plugins.extend(tools)
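
A short sketch of resolving an agent class from its type with the static helper above; note that only AgentType.rewoo is currently mapped, and every other type raises ValueError:

# hypothetical usage sketch
from kotaemon.agents.base import AgentType

agent_cls = AgentType.get_agent_class(AgentType.rewoo)
assert agent_cls.__name__ == "RewooAgent"
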
81 knowledgehub/agents/langchain.py Normal file
@@ -0,0 +1,81 @@
from typing import List, Optional

from langchain.agents import AgentType as LCAgentType
from langchain.agents import initialize_agent
from langchain.agents.agent import AgentExecutor as LCAgentExecutor

from kotaemon.agents.tools import BaseTool
from kotaemon.base.schema import Document
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM

from .base import AgentType, BaseAgent


class LangchainAgent(BaseAgent):
    """Wrapper for Langchain agents"""

    name: str = "LangchainAgent"
    agent_type: AgentType
    description: str = "LangchainAgent for answering multi-step reasoning questions"
    AGENT_TYPE_MAP = {
        AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
        AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
        AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION,
        AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH,
    }
    agent: Optional[LCAgentExecutor] = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.agent_type not in self.AGENT_TYPE_MAP:
            raise NotImplementedError(
                f"AgentType {self.agent_type} not supported by Langchain wrapper"
            )
        self.update_agent_tools()

    def update_agent_tools(self):
        assert isinstance(self.llm, (ChatLLM, LLM))
        langchain_plugins = [tool.to_langchain_format() for tool in self.plugins]

        # a fix for the search_doc tool name:
        # use "Intermediate Answer" for the self-ask agent
        found_search_tool = False
        if self.agent_type == AgentType.self_ask:
            for plugin in langchain_plugins:
                if plugin.name == "search_doc":
                    plugin.name = "Intermediate Answer"
                    langchain_plugins = [plugin]
                    found_search_tool = True
                    break

        if self.agent_type != AgentType.self_ask or found_search_tool:
            # reinitialize the Langchain AgentExecutor
            self.agent = initialize_agent(
                langchain_plugins,
                self.llm.agent,
                agent=self.AGENT_TYPE_MAP[self.agent_type],
                handle_parsing_errors=True,
                verbose=True,
            )

    def add_tools(self, tools: List[BaseTool]) -> None:
        super().add_tools(tools)
        self.update_agent_tools()
        return

    def run(self, instruction: str) -> Document:
        assert (
            self.agent is not None
        ), "Langchain AgentExecutor is not correctly initialized"
        # Langchain AgentExecutor call
        output = self.agent(instruction)["output"]
        return Document(
            text=output,
            metadata={
                "agent": "langchain",
                "cost": 0.0,
                "usage": 0,
            },
        )
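
A hedged usage sketch for the wrapper above. It assumes an LLM object that exposes a Langchain-compatible model through its `.agent` attribute, since `update_agent_tools` passes `self.llm.agent` to `initialize_agent`:

# hypothetical usage sketch
from kotaemon.agents import LangchainAgent
from kotaemon.agents.base import AgentType
from kotaemon.agents.tools import WikipediaTool
from kotaemon.llms import AzureChatOpenAI

agent = LangchainAgent(
    llm=AzureChatOpenAI(),  # assumed to expose a Langchain model via `.agent`
    agent_type=AgentType.react,
    plugins=[WikipediaTool()],
)
doc = agent.run("Which country is Mount Everest in?")
print(doc.text)
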
0 knowledgehub/agents/output/__init__.py Normal file
219 knowledgehub/agents/output/base.py Normal file
@@ -0,0 +1,219 @@
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Union


def check_log():
    """
    Checks if logging has been enabled.
    :return: True if logging has been enabled, False otherwise.
    :rtype: bool
    """
    return os.environ.get("LOG_PATH", None) is not None


class BaseScratchPad:
    """
    Base class for output handlers.

    Attributes:
    -----------
    logger : logging.Logger
        The logger object to log messages.

    Methods:
    --------
    stop():
        Stop the output.
    update_status(output: str, **kwargs):
        Update the status of the output.
    thinking(name: str):
        Log that a process is thinking.
    done(_all=False):
        Log that the process is done.
    stream_print(item: str):
        Not implemented.
    json_print(item: Dict[str, Any]):
        Log a JSON object.
    panel_print(item: Any, title: str = "Output", stream: bool = False):
        Log a panel output.
    clear():
        Not implemented.
    print(content: str, **kwargs):
        Log arbitrary content.
    format_json(json_obj: str):
        Format a JSON object.
    debug(content: str, **kwargs):
        Log a debug message.
    info(content: str, **kwargs):
        Log an informational message.
    warning(content: str, **kwargs):
        Log a warning message.
    error(content: str, **kwargs):
        Log an error message.
    critical(content: str, **kwargs):
        Log a critical message.
    """

    def __init__(self):
        """
        Initialize the BaseScratchPad object.
        """
        self.logger = logging
        self.log = []

    def stop(self):
        """
        Stop the output.
        """

    def update_status(self, output: str, **kwargs):
        """
        Update the status of the output.
        """
        if check_log():
            self.logger.info(output)

    def thinking(self, name: str):
        """
        Log that a process is thinking.
        """
        if check_log():
            self.logger.info(f"{name} is thinking...")

    def done(self, _all=False):
        """
        Log that the process is done.
        """
        if check_log():
            self.logger.info("Done")

    def stream_print(self, item: str):
        """
        Stream print. Not implemented.
        """

    def json_print(self, item: Dict[str, Any]):
        """
        Log a JSON object.
        """
        if check_log():
            self.logger.info(json.dumps(item, indent=2))

    def panel_print(self, item: Any, title: str = "Output", stream: bool = False):
        """
        Log a panel output.

        Args:
            item : Any
                The item to log.
            title : str, optional
                The title of the panel, defaults to "Output".
            stream : bool, optional
        """
        if not stream:
            self.log.append(item)
        if check_log():
            self.logger.info("-" * 20)
            self.logger.info(item)
            self.logger.info("-" * 20)

    def clear(self):
        """
        Not implemented.
        """

    def print(self, content: str, **kwargs):
        """
        Log arbitrary content.
        """
        self.log.append(content)
        if check_log():
            self.logger.info(content)

    def format_json(self, json_obj: str):
        """
        Format a JSON object.
        """
        formatted_json = json.dumps(json_obj, indent=2)
        return formatted_json

    def debug(self, content: str, **kwargs):
        """
        Log a debug message.
        """
        if check_log():
            self.logger.debug(content, **kwargs)

    def info(self, content: str, **kwargs):
        """
        Log an informational message.
        """
        if check_log():
            self.logger.info(content, **kwargs)

    def warning(self, content: str, **kwargs):
        """
        Log a warning message.
        """
        if check_log():
            self.logger.warning(content, **kwargs)

    def error(self, content: str, **kwargs):
        """
        Log an error message.
        """
        if check_log():
            self.logger.error(content, **kwargs)

    def critical(self, content: str, **kwargs):
        """
        Log a critical message.
        """
        if check_log():
            self.logger.critical(content, **kwargs)


@dataclass
class AgentAction:
    """Agent's action to take.

    Args:
        tool: The tool to invoke.
        tool_input: The input to the tool.
        log: The log message.
    """

    tool: str
    tool_input: Union[str, dict]
    log: str


class AgentFinish(NamedTuple):
    """Agent's return value when finishing execution.

    Args:
        return_values: The return values of the agent.
        log: The log message.
    """

    return_values: dict
    log: str
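
All of the BaseScratchPad methods above key their logging off the LOG_PATH environment variable via check_log; any non-empty value enables forwarding to the logging module. A minimal sketch:

# hypothetical usage sketch
import os

from kotaemon.agents.output.base import BaseScratchPad, check_log

os.environ["LOG_PATH"] = "logs/agent.log"  # any value enables logging
assert check_log()

pad = BaseScratchPad()
pad.info("Planner run successful.")  # forwarded to logging.info
pad.panel_print("some evidence")     # also appended to pad.log
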
3 knowledgehub/agents/react/__init__.py Normal file
@@ -0,0 +1,3 @@
from .agent import ReactAgent

__all__ = ["ReactAgent"]
197 knowledgehub/agents/react/agent.py Normal file
@@ -0,0 +1,197 @@
import logging
import re
from typing import Optional

from theflow import Param

from kotaemon.base.schema import Document
from kotaemon.llms import PromptTemplate

from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import AgentAction, AgentFinish

FINAL_ANSWER_ACTION = "Final Answer:"


class ReactAgent(BaseAgent):
    """
    Sequential ReactAgent class inherited from BaseAgent.
    Implements the ReAct agent paradigm https://arxiv.org/pdf/2210.03629.pdf
    """

    name: str = "ReactAgent"
    agent_type: AgentType = AgentType.react
    description: str = "ReactAgent for answering multi-step reasoning questions"
    llm: BaseLLM | dict[str, BaseLLM]
    prompt_template: Optional[PromptTemplate] = None
    plugins: list[BaseTool] = Param(
        default_callback=lambda _: [], help="List of tools to be used in the agent."
    )
    examples: dict[str, str | list[str]] = Param(
        default_callback=lambda _: {}, help="Examples to be used in the agent."
    )
    intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = Param(
        default_callback=lambda _: [],
        help="List of AgentAction and observation (tool) output",
    )
    max_iterations = 10
    strict_decode: bool = False

    def _compose_plugin_description(self) -> str:
        """
        Compose the worker prompt from the workers.

        Example:
        toolname1[input]: tool1 description
        toolname2[input]: tool2 description
        """
        prompt = ""
        try:
            for plugin in self.plugins:
                prompt += f"{plugin.name}[input]: {plugin.description}\n"
        except Exception:
            raise ValueError("Worker must have a name and description.")
        return prompt

    def _construct_scratchpad(
        self, intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = []
    ) -> str:
        """Construct the scratchpad that lets the agent continue its thought process."""
        thoughts = ""
        for action, observation in intermediate_steps:
            thoughts += action.log
            thoughts += f"\nObservation: {observation}\nThought:"
        return thoughts

    def _parse_output(self, text: str) -> Optional[AgentAction | AgentFinish]:
        """
        Parse text output from the LLM for the next Action or Final Answer.
        Uses a regex on "Action:\n Action Input:\n" for the next Action
        and FINAL_ANSWER_ACTION for the Final Answer.

        Args:
            text[str]: input text to parse
        """
        includes_answer = FINAL_ANSWER_ACTION in text
        regex = (
            r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
        )
        action_match = re.search(regex, text, re.DOTALL)
        action_output: Optional[AgentAction | AgentFinish] = None
        if action_match:
            if includes_answer:
                raise Exception(
                    "Parsing LLM output produced both a final answer "
                    f"and a parse-able action: {text}"
                )
            action = action_match.group(1).strip()
            action_input = action_match.group(2)
            tool_input = action_input.strip(" ")
            # ensure that if it's a well-formed SQL query we don't remove any
            # trailing " chars
            if tool_input.startswith("SELECT ") is False:
                tool_input = tool_input.strip('"')

            action_output = AgentAction(action, tool_input, text)

        elif includes_answer:
            action_output = AgentFinish(
                {"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
            )
        else:
            if self.strict_decode:
                raise Exception(f"Could not parse LLM output: `{text}`")
            else:
                action_output = AgentFinish({"output": text}, text)

        return action_output

    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.
        """
        agent_scratchpad = self._construct_scratchpad(self.intermediate_steps)
        tool_description = self._compose_plugin_description()
        tool_names = ", ".join([plugin.name for plugin in self.plugins])
        if self.prompt_template is None:
            from .prompt import zero_shot_react_prompt

            self.prompt_template = zero_shot_react_prompt
        return self.prompt_template.populate(
            instruction=instruction,
            agent_scratchpad=agent_scratchpad,
            tool_description=tool_description,
            tool_names=tool_names,
        )

    def _format_function_map(self) -> dict[str, BaseTool]:
        """Format the function map for the OpenAI function API.

        Return:
            Dict[str, Callable]: The function map.
        """
        # Map the function name to the real function object.
        function_map = {}
        for plugin in self.plugins:
            function_map[plugin.name] = plugin
        return function_map

    def clear(self):
        """
        Clear and reset the agent.
        """
        self.intermediate_steps = []

    def run(self, instruction, max_iterations=None):
        """
        Run the agent with the given instruction.

        Args:
            instruction: Instruction to run the agent with.
            max_iterations: Maximum number of reasoning iterations,
                defaults to 10.

        Return:
            AgentOutput object.
        """
        if not max_iterations:
            max_iterations = self.max_iterations
        assert max_iterations > 0

        self.clear()
        logging.info(f"Running {self.name} with instruction: {instruction}")
        total_cost = 0.0
        total_token = 0

        for _ in range(max_iterations):
            prompt = self._compose_prompt(instruction)
            logging.info(f"Prompt: {prompt}")
            response = self.llm(prompt, stop=["Observation:"])  # type: ignore
            response_text = response.text
            logging.info(f"Response: {response_text}")
            action_step = self._parse_output(response_text)
            if action_step is None:
                raise ValueError("Invalid action")
            is_finished_chain = isinstance(action_step, AgentFinish)
            if is_finished_chain:
                result = ""
            else:
                assert isinstance(action_step, AgentAction)
                action_name = action_step.tool
                tool_input = action_step.tool_input
                logging.info(f"Action: {action_name}")
                logging.info(f"Tool Input: {tool_input}")
                result = self._format_function_map()[action_name](tool_input)
                logging.info(f"Result: {result}")

            self.intermediate_steps.append((action_step, result))
            if is_finished_chain:
                break

        return Document(
            text=response_text,
            metadata={
                "agent": "react",
                "cost": total_cost,
                "usage": total_token,
            },
        )
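
A hedged end-to-end sketch of the ReAct loop above, assuming Azure OpenAI credentials are configured for AzureChatOpenAI:

# hypothetical usage sketch
from kotaemon.agents import ReactAgent
from kotaemon.agents.tools import LLMTool, WikipediaTool
from kotaemon.llms import AzureChatOpenAI

llm = AzureChatOpenAI()
agent = ReactAgent(llm=llm, plugins=[WikipediaTool(), LLMTool(llm=llm)])
result = agent.run("When was the ReAct paper published?", max_iterations=5)
print(result.text)               # final LLM response
print(agent.intermediate_steps)  # (AgentAction/AgentFinish, observation) pairs
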
28 knowledgehub/agents/react/prompt.py Normal file
@@ -0,0 +1,28 @@
# flake8: noqa

from kotaemon.llms import PromptTemplate

zero_shot_react_prompt = PromptTemplate(
    template="""Answer the following questions as best you can. You have access to the following tools:
{tool_description}
Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do

Action: the action to take, should be one of [{tool_names}]

Action Input: the input to the action

Observation: the result of the action

... (this Thought/Action/Action Input/Observation can repeat N times)
#Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin! After each Action Input.

Question: {instruction}
Thought:{agent_scratchpad}
"""
)
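
Populating this template mirrors what ReactAgent._compose_prompt does; a small sketch with made-up tool values:

# hypothetical usage sketch
from kotaemon.agents.react.prompt import zero_shot_react_prompt

prompt = zero_shot_react_prompt.populate(
    instruction="What is 2 + 2?",
    agent_scratchpad="",
    tool_description="llm[input]: A pretrained LLM like yourself.",
    tool_names="llm",
)
print(prompt)
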
3 knowledgehub/agents/rewoo/__init__.py Normal file
@@ -0,0 +1,3 @@
from .agent import RewooAgent

__all__ = ["RewooAgent"]
279 knowledgehub/agents/rewoo/agent.py Normal file
@@ -0,0 +1,279 @@
import logging
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from theflow import Param

from kotaemon.base.schema import Document
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
from kotaemon.pipelines.citation import CitationPipeline

from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from ..utils import get_plugin_response_content
from .planner import Planner
from .solver import Solver


class RewooAgent(BaseAgent):
    """Distributive RewooAgent class inherited from BaseAgent.
    Implements the ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf"""

    name: str = "RewooAgent"
    agent_type: AgentType = AgentType.rewoo
    description: str = "RewooAgent for answering multi-step reasoning questions"
    llm: BaseLLM | dict[str, BaseLLM]  # {"Planner": xxx, "Solver": xxx}
    prompt_template: dict[str, PromptTemplate] = Param(
        default_callback=lambda _: {},
        help="A dict to supply different prompts to the agent.",
    )
    plugins: list[BaseTool] = Param(
        default_callback=lambda _: [], help="A list of plugins to be used in the model."
    )
    examples: dict[str, str | list[str]] = Param(
        default_callback=lambda _: {}, help="Examples to be used in the agent."
    )

    def _get_llms(self):
        if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM):
            return {"Planner": self.llm, "Solver": self.llm}
        elif (
            isinstance(self.llm, dict)
            and "Planner" in self.llm
            and "Solver" in self.llm
        ):
            return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
        else:
            raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")

    def _parse_plan_map(
        self, planner_response: str
    ) -> tuple[dict[str, list[str]], dict[str, str]]:
        """
        Parse planner output. It should be an n-to-n mapping from Plans to #Es.
        This is because sometimes the LLM cannot follow the strict output format.
        Example:
            #Plan1
            #E1
            #E2
        should result in: {"#Plan1": ["#E1", "#E2"]}
        Or:
            #Plan1
            #Plan2
            #E1
        should result in: {"#Plan1": [], "#Plan2": ["#E1"]}
        This function also returns the plan descriptions.

        Returns:
            tuple[Dict[str, List[str]], Dict[str, str]]:
                The plan-to-evidence map and the plan descriptions.
        """
        valid_chunk = [
            line
            for line in planner_response.splitlines()
            if line.startswith("#Plan") or line.startswith("#E")
        ]

        plan_to_es: dict[str, list[str]] = dict()
        plans: dict[str, str] = dict()
        for line in valid_chunk:
            if line.startswith("#Plan"):
                plan = line.split(":", 1)[0].strip()
                plans[plan] = line.split(":", 1)[1].strip()
                plan_to_es[plan] = []
            elif line.startswith("#E"):
                plan_to_es[plan].append(line.split(":", 1)[0].strip())

        return plan_to_es, plans

    def _parse_planner_evidences(
        self, planner_response: str
    ) -> tuple[dict[str, str], list[list[str]]]:
        """
        Parse planner output. This returns a mapping from #E to tool call.
        It also identifies the level of each #E in the dependency map.
        Example:
            {
            "#E1": "Tool1", "#E2": "Tool2",
            "#E3": "Tool3", "#E4": "Tool4"
            }, [[#E1, #E2], [#E3, #E4]]

        Returns:
            tuple[dict[str, str], List[List[str]]]:
                A mapping from #E to tool call and a list of levels.
        """
        evidences: dict[str, str] = dict()
        dependence: dict[str, list[str]] = dict()
        for line in planner_response.splitlines():
            if line.startswith("#E") and line[2].isdigit():
                e, tool_call = line.split(":", 1)
                e, tool_call = e.strip(), tool_call.strip()
                if len(e) == 3:
                    dependence[e] = []
                    evidences[e] = tool_call
                    for var in re.findall(r"#E\d+", tool_call):
                        if var in evidences:
                            dependence[e].append(var)
                else:
                    evidences[e] = "No evidence found"
        level = []
        while dependence:
            select = [i for i in dependence if not dependence[i]]
            if len(select) == 0:
                raise ValueError("Circular dependency detected.")
            level.append(select)
            for item in select:
                dependence.pop(item)
            for item in dependence:
                for i in select:
                    if i in dependence[item]:
                        dependence[item].remove(i)

        return evidences, level

    def _run_plugin(
        self,
        e: str,
        planner_evidences: dict[str, str],
        worker_evidences: dict[str, str],
        output=BaseScratchPad(),
    ):
        """
        Run a plugin for a given evidence.
        This function also accumulates the cost and tokens.
        """
        result = dict(e=e, plugin_cost=0, plugin_token=0, evidence="")
        tool_call = planner_evidences[e]
        if "[" not in tool_call:
            result["evidence"] = tool_call
        else:
            tool, tool_input = tool_call.split("[", 1)
            tool_input = tool_input[:-1]
            # find variables in input and replace with previous evidences
            for var in re.findall(r"#E\d+", tool_input):
                if var in worker_evidences:
                    tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
            try:
                selected_plugin = self._find_plugin(tool)
                if selected_plugin is None:
                    raise ValueError("Invalid plugin detected")
                tool_response = selected_plugin(tool_input)
                result["evidence"] = get_plugin_response_content(tool_response)
            except ValueError:
                result["evidence"] = "No evidence found."
            finally:
                output.panel_print(
                    result["evidence"], f"[green] Function Response of [blue]{tool}: "
                )
        return result

    def _get_worker_evidence(
        self,
        planner_evidences: dict[str, str],
        evidences_level: list[list[str]],
        output=BaseScratchPad(),
    ) -> Any:
        """
        Parallel execution of plugins in a DAG for speedup.
        This is one of the core benefits of ReWOO agents.

        Args:
            planner_evidences: A mapping from #E to tool call.
            evidences_level: A list of levels of evidences.
                Calculated from the DAG of plugin calls.
            output: Output object, defaults to BaseScratchPad().
        Returns:
            A mapping from #E to tool output, plus the accumulated cost and tokens.
        """
        worker_evidences: dict[str, str] = dict()
        plugin_cost, plugin_token = 0.0, 0.0
        with ThreadPoolExecutor() as pool:
            for level in evidences_level:
                results = []
                for e in level:
                    results.append(
                        pool.submit(
                            self._run_plugin,
                            e,
                            planner_evidences,
                            worker_evidences,
                            output,
                        )
                    )
                if len(results) > 1:
                    output.update_status(f"Running tasks {level} in parallel.")
                else:
                    output.update_status(f"Running task {level[0]}.")
                for r in results:
                    resp = r.result()
                    plugin_cost += resp["plugin_cost"]
                    plugin_token += resp["plugin_token"]
                    worker_evidences[resp["e"]] = resp["evidence"]
                output.done()

        return worker_evidences, plugin_cost, plugin_token

    def _find_plugin(self, name: str):
        for p in self.plugins:
            if p.name == name:
                return p

    def run(self, instruction: str, use_citation: bool = False) -> Document:
        """
        Run the agent with a given instruction.
        """
        logging.info(f"Running {self.name} with instruction: {instruction}")
        total_cost = 0.0
        total_token = 0

        planner_llm = self._get_llms()["Planner"]
        solver_llm = self._get_llms()["Solver"]

        planner = Planner(
            model=planner_llm,
            plugins=self.plugins,
            prompt_template=self.prompt_template.get("Planner", None),
            examples=self.examples.get("Planner", None),
        )
        solver = Solver(
            model=solver_llm,
            prompt_template=self.prompt_template.get("Solver", None),
            examples=self.examples.get("Solver", None),
        )

        # Plan
        planner_output = planner(instruction)
        planner_text_output = planner_output.text
        plan_to_es, plans = self._parse_plan_map(planner_text_output)
        planner_evidences, evidence_level = self._parse_planner_evidences(
            planner_text_output
        )

        # Work
        worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
            planner_evidences, evidence_level
        )
        worker_log = ""
        for plan in plan_to_es:
            worker_log += f"{plan}: {plans[plan]}\n"
            for e in plan_to_es[plan]:
                worker_log += f"{e}: {worker_evidences[e]}\n"

        # Solve
        solver_output = solver(instruction, worker_log)
        solver_output_text = solver_output.text
        if use_citation:
            citation_pipeline = CitationPipeline(llm=solver_llm)
            citation = citation_pipeline(context=worker_log, question=instruction)
        else:
            citation = None

        return Document(
            text=solver_output_text,
            metadata={
                "agent": "rewoo",
                "cost": total_cost,
                "usage": total_token,
                "citation": citation,
            },
        )
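
A hedged usage sketch for the plan-work-solve flow above; per _get_llms, the llm field accepts either a single LLM shared by both roles or a {"Planner": ..., "Solver": ...} dict:

# hypothetical usage sketch
from kotaemon.agents import RewooAgent
from kotaemon.agents.tools import GoogleSearchTool, LLMTool
from kotaemon.llms import AzureChatOpenAI

llm = AzureChatOpenAI()
agent = RewooAgent(
    llm={"Planner": llm, "Solver": llm},  # or a single LLM for both roles
    plugins=[GoogleSearchTool(), LLMTool(llm=llm)],
)
doc = agent.run("What is the 4th root of 64 to the power of 3?")
print(doc.text)
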
83 knowledgehub/agents/rewoo/planner.py Normal file
@@ -0,0 +1,83 @@
from typing import Any, List, Optional, Union

from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate

from ..base import BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt


class Planner(BaseComponent):
    model: BaseLLM
    prompt_template: Optional[PromptTemplate] = None
    examples: Optional[Union[str, List[str]]] = None
    plugins: List[BaseTool]

    def _compose_worker_description(self) -> str:
        """
        Compose the worker prompt from the workers.

        Example:
        toolname1[input]: tool1 description
        toolname2[input]: tool2 description
        """
        prompt = ""
        try:
            for worker in self.plugins:
                prompt += f"{worker.name}[input]: {worker.description}\n"
        except Exception:
            raise ValueError("Worker must have a name and description.")
        return prompt

    def _compose_fewshot_prompt(self) -> str:
        if self.examples is None:
            return ""
        if isinstance(self.examples, str):
            return self.examples
        else:
            return "\n\n".join([e.strip("\n") for e in self.examples])

    def _compose_prompt(self, instruction) -> str:
        """
        Compose the prompt from template, worker description, examples and instruction.
        """
        worker_description = self._compose_worker_description()
        fewshot = self._compose_fewshot_prompt()
        if self.prompt_template is not None:
            if "fewshot" in self.prompt_template.placeholders:
                return self.prompt_template.populate(
                    tool_description=worker_description,
                    fewshot=fewshot,
                    task=instruction,
                )
            else:
                return self.prompt_template.populate(
                    tool_description=worker_description, task=instruction
                )
        else:
            if self.examples is not None:
                return few_shot_planner_prompt.populate(
                    tool_description=worker_description,
                    fewshot=fewshot,
                    task=instruction,
                )
            else:
                return zero_shot_planner_prompt.populate(
                    tool_description=worker_description, task=instruction
                )

    def run(self, instruction: str, output: BaseScratchPad = BaseScratchPad()) -> Any:
        response = None
        output.info("Running Planner")
        prompt = self._compose_prompt(instruction)
        output.debug(f"Prompt: {prompt}")
        try:
            response = self.model(prompt)
            self.log_progress(".planner", response=response)
            output.info("Planner run successful.")
        except ValueError as e:
            output.error("Planner failed to retrieve response from LLM")
            raise ValueError("Planner failed to retrieve response from LLM") from e

        return response
119 knowledgehub/agents/rewoo/prompt.py Normal file
@@ -0,0 +1,119 @@
# flake8: noqa

from kotaemon.llms import PromptTemplate

zero_shot_planner_prompt = PromptTemplate(
    template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.

##Available Tools##
{tool_description}

##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>] (eg. Search[What is Python])
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...

##Your Task##
{task}

##Now Begin##
"""
)

one_shot_planner_prompt = PromptTemplate(
    template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.

##Available Tools##
{tool_description}

##Output Format##
#Plan1: <describe your plan here>
#E1: <toolname>[<input here>]
#Plan2: <describe next plan>
#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]
And so on...

##Example##
Task: What is the 4th root of 64 to the power of 3?
#Plan1: Find the 4th root of 64
#E1: Calculator[64^(1/4)]
#Plan2: Raise the result from #Plan1 to the power of 3
#E2: Calculator[#E1^3]

##Your Task##
{task}

##Now Begin##
"""
)

few_shot_planner_prompt = PromptTemplate(
    template="""You are an AI agent who makes step-by-step plans to solve a problem under the help of external tools.
For each step, make one plan followed by one tool-call, which will be executed later to retrieve evidence for that step.
You should store each evidence into a distinct variable #E1, #E2, #E3 ... that can be referred to in later tool-call inputs.

##Available Tools##
{tool_description}

##Output Format (Replace '<...>')##
#Plan1: <describe your plan here>
#E1: <toolname>[<input>]
#Plan2: <describe next plan>
#E2: <toolname>[<input, you can use #E1 to represent its expected output>]
And so on...

##Examples##
{fewshot}

##Your Task##
{task}

##Now Begin##
"""
)

zero_shot_solver_prompt = PromptTemplate(
    template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.

##My Plans and Evidences##
{plan_evidence}

##Example Output##
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
So, <your conclusion>.

##Your Task##
{task}

##Now Begin##
"""
)

few_shot_solver_prompt = PromptTemplate(
    template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
Your task is to briefly summarize each step, then make a short final conclusion for your task.

##My Plans and Evidences##
{plan_evidence}

##Example Output##
First, I <did something> , and I think <...>; Second, I <...>, and I think <...>; ....
So, <your conclusion>.

##Example##
{fewshot}

##Your Task##
{task}

##Now Begin##
"""
)
66 knowledgehub/agents/rewoo/solver.py Normal file
@@ -0,0 +1,66 @@
from typing import Any, List, Optional, Union

from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate

from ..base import BaseLLM
from ..output.base import BaseScratchPad
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt


class Solver(BaseComponent):
    model: BaseLLM
    prompt_template: Optional[PromptTemplate] = None
    examples: Optional[Union[str, List[str]]] = None

    def _compose_fewshot_prompt(self) -> str:
        if self.examples is None:
            return ""
        if isinstance(self.examples, str):
            return self.examples
        else:
            return "\n\n".join([e.strip("\n") for e in self.examples])

    def _compose_prompt(self, instruction, plan_evidence) -> str:
        """
        Compose the prompt from template, plan & evidence, examples and instruction.
        """
        fewshot = self._compose_fewshot_prompt()
        if self.prompt_template is not None:
            if "fewshot" in self.prompt_template.placeholders:
                return self.prompt_template.populate(
                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
                )
            else:
                return self.prompt_template.populate(
                    plan_evidence=plan_evidence, task=instruction
                )
        else:
            if self.examples is not None:
                return few_shot_solver_prompt.populate(
                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
                )
            else:
                return zero_shot_solver_prompt.populate(
                    plan_evidence=plan_evidence, task=instruction
                )

    def run(
        self,
        instruction: str,
        plan_evidence: str,
        output: BaseScratchPad = BaseScratchPad(),
    ) -> Any:
        response = None
        output.info("Running Solver")
        output.debug(f"Instruction: {instruction}")
        output.debug(f"Plan Evidence: {plan_evidence}")
        prompt = self._compose_prompt(instruction, plan_evidence)
        output.debug(f"Prompt: {prompt}")
        try:
            response = self.model(prompt)
            output.info("Solver run successful.")
        except ValueError:
            output.error("Solver failed to retrieve response from LLM")

        return response
6 knowledgehub/agents/tools/__init__.py Normal file
@@ -0,0 +1,6 @@
from .base import BaseTool, ComponentTool
from .google import GoogleSearchTool
from .llm import LLMTool
from .wikipedia import WikipediaTool

__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool", "LLMTool"]
137 knowledgehub/agents/tools/base.py Normal file
@@ -0,0 +1,137 @@
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

from langchain.agents import Tool as LCTool
from pydantic import BaseModel

from kotaemon.base import BaseComponent


class ToolException(Exception):
    """An optional exception that a tool throws when an execution error occurs.

    When this exception is thrown, the agent will not stop working,
    but will handle the exception according to the handle_tool_error
    variable of the tool, and the processing result will be returned
    to the agent as an observation, and printed in red on the console.
    """


class BaseTool(BaseComponent):
    name: str
    """The unique name of the tool that clearly communicates its purpose."""
    description: str
    """Description used to tell the model how/when/why to use the tool.
    You can provide few-shot examples as a part of the description. This will be
    input to the prompt of the LLM.
    """
    args_schema: Optional[Type[BaseModel]] = None
    """Pydantic model class to validate and parse the tool's input arguments."""
    verbose: bool = False
    """Whether to log the tool's progress."""
    handle_tool_error: Optional[
        Union[bool, str, Callable[[ToolException], str]]
    ] = False
    """Handle the content of the ToolException thrown."""

    def _parse_input(
        self,
        tool_input: Union[str, Dict],
    ) -> Union[str, Dict[str, Any]]:
        """Convert tool input to a pydantic model."""
        args_schema = self.args_schema
        if isinstance(tool_input, str):
            if args_schema is not None:
                key_ = next(iter(args_schema.__fields__.keys()))
                args_schema.validate({key_: tool_input})
            return tool_input
        else:
            if args_schema is not None:
                result = args_schema.parse_obj(tool_input)
                return {k: v for k, v in result.dict().items() if k in tool_input}
        return tool_input

    def _run_tool(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Call the tool."""
        raise NotImplementedError(f"_run_tool is not implemented for {self.name}")

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        # For backwards compatibility, if run_input is a string,
        # pass as a positional argument.
        if isinstance(tool_input, str):
            return (tool_input,), {}
        else:
            return (), tool_input

    def _handle_tool_error(self, e: ToolException) -> Any:
        """Handle the content of the ToolException thrown."""
        observation = None
        if not self.handle_tool_error:
            raise e
        elif isinstance(self.handle_tool_error, bool):
            if e.args:
                observation = e.args[0]
            else:
                observation = "Tool execution error"
        elif isinstance(self.handle_tool_error, str):
            observation = self.handle_tool_error
        elif callable(self.handle_tool_error):
            observation = self.handle_tool_error(e)
        else:
            raise ValueError(
                f"Got unexpected type of `handle_tool_error`. Expected bool, str "
                f"or callable. Received: {self.handle_tool_error}"
            )
        return observation

    def to_langchain_format(self) -> LCTool:
        """Convert this tool to the Langchain format to use with its agents"""
        return LCTool(name=self.name, description=self.description, func=self.run)

    def run(
        self,
        tool_input: Union[str, Dict],
        verbose: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the tool."""
        parsed_input = self._parse_input(tool_input)
        # TODO (verbose_): Add logging
        try:
            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
            call_kwargs = {**kwargs, **tool_kwargs}
            observation = self._run_tool(*tool_args, **call_kwargs)
        except ToolException as e:
            observation = self._handle_tool_error(e)
            return observation
        else:
            return observation

    @classmethod
    def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
        """Wrapper for a Langchain Tool"""
        new_tool = BaseTool(
            name=langchain_tool.name, description=langchain_tool.description
        )
        new_tool._run_tool = langchain_tool._run  # type: ignore
        return new_tool


class ComponentTool(BaseTool):
    """
    A tool based on another pipeline / BaseComponent to be used
    as its main entry point
    """

    component: BaseComponent
    postprocessor: Optional[Callable] = None

    def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
        output = self.component(*args, **kwargs)
        if self.postprocessor:
            output = self.postprocessor(output)

        return output
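
Since tools no longer subclass BaseAgent, a custom tool only needs a name, a description, and a _run_tool implementation. A minimal, hypothetical example following the pattern of the built-in tools below:

# hypothetical usage sketch
from typing import AnyStr

from kotaemon.agents.tools import BaseTool


class EchoTool(BaseTool):
    name = "echo"
    description = "Returns its input unchanged. Input can be any string."

    def _run_tool(self, query: AnyStr) -> str:
        return str(query)


assert EchoTool().run("hello") == "hello"
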
51 knowledgehub/agents/tools/google.py Normal file
@@ -0,0 +1,51 @@
from typing import AnyStr, Optional, Type

from langchain.utilities import SerpAPIWrapper
from pydantic import BaseModel, Field

from .base import BaseTool


class GoogleSearchArgs(BaseModel):
    query: str = Field(..., description="a search query")


class GoogleSearchTool(BaseTool):
    name = "google_search"
    description = (
        "A search engine retrieving top search results as snippets from Google. "
        "Input should be a search query."
    )
    args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs

    def _run_tool(self, query: AnyStr) -> str:
        try:
            from googlesearch import search
        except ImportError:
            raise ImportError(
                "install googlesearch using `pip3 install googlesearch-python` to "
                "use this tool"
            )
        output = ""
        search_results = search(query, advanced=True)
        if search_results:
            output = "\n".join(
                "{} {}".format(item.title, item.description) for item in search_results
            )

        return output


class SerpTool(BaseTool):
    name = "google_search"
    description = (
        "Worker that searches results from Google. Useful when you need to find short "
        "and succinct answers about a specific topic. Input should be a search query."
    )
    args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs

    def _run_tool(self, query: AnyStr) -> str:
        tool = SerpAPIWrapper()
        evidence = tool.run(query)

        return evidence
34 knowledgehub/agents/tools/llm.py Normal file
@@ -0,0 +1,34 @@
from typing import AnyStr, Optional, Type, Union

from pydantic import BaseModel, Field

from kotaemon.llms import LLM, AzureChatOpenAI, ChatLLM

from .base import BaseTool, ToolException

BaseLLM = Union[ChatLLM, LLM]


class LLMArgs(BaseModel):
    query: str = Field(..., description="a search question or prompt")


class LLMTool(BaseTool):
    name = "llm"
    description = (
        "A pretrained LLM like yourself. Useful when you need to act with "
        "general world knowledge and common sense. Prioritize it when you "
        "are confident in solving the problem "
        "yourself. Input can be any instruction."
    )
    llm: BaseLLM = AzureChatOpenAI()
    args_schema: Optional[Type[BaseModel]] = LLMArgs

    def _run_tool(self, query: AnyStr) -> str:
        output = None
        try:
            response = self.llm(query)
        except ValueError:
            raise ToolException("LLM Tool call failed")
        output = response.text
        return output
66 knowledgehub/agents/tools/wikipedia.py Normal file
@@ -0,0 +1,66 @@
from typing import Any, AnyStr, Optional, Type, Union

from pydantic import BaseModel, Field

from kotaemon.base import Document

from .base import BaseTool


class Wiki:
    """Wrapper around the wikipedia API."""

    def __init__(self) -> None:
        """Check that the wikipedia package is installed."""
        try:
            import wikipedia  # noqa: F401
        except ImportError:
            raise ValueError(
                "Could not import wikipedia python package. "
                "Please install it with `pip install wikipedia`."
            )

    def search(self, search: str) -> Union[str, Document]:
        """Try to search for a wiki page.

        If the page exists, return the page summary and a PageWithLookups object.
        If the page does not exist, return similar entries.
        """
        import wikipedia

        try:
            # fetch the page once and reuse it for both content and url
            page = wikipedia.page(search)
            result: Union[str, Document] = Document(
                text=page.content, metadata={"page": page.url}
            )
        except wikipedia.PageError:
            result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
        except wikipedia.DisambiguationError:
            result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
        return result


class WikipediaArgs(BaseModel):
    query: str = Field(..., description="a search query as input to wikipedia")


class WikipediaTool(BaseTool):
    """Tool that adds the capability to query the Wikipedia API."""

    name = "wikipedia"
    description = (
        "Search engine from Wikipedia, retrieving relevant wiki pages. "
        "Useful when you need to get holistic knowledge about people, "
        "places, companies, historical events, or other subjects. "
        "Input should be a search query."
    )
    args_schema: Optional[Type[BaseModel]] = WikipediaArgs
    doc_store: Any = None

    def _run_tool(self, query: AnyStr) -> AnyStr:
        if not self.doc_store:
            self.doc_store = Wiki()
        tool = self.doc_store
        evidence = tool.search(query)
        return evidence
22 knowledgehub/agents/utils.py Normal file
@@ -0,0 +1,22 @@
from kotaemon.base import Document


def get_plugin_response_content(output) -> str:
    """
    Wrapper for AgentOutput content return
    """
    if isinstance(output, Document):
        return output.text
    else:
        return str(output)


def calculate_cost(model_name: str, prompt_token: int, completion_token: int) -> float:
    """
    Calculate the cost of a prompt and completion.

    Returns:
        float: Cost of the provided model name with the provided token information
    """
    # TODO: to be implemented
    return 0.0