Refactor agents and tools (#91)
* Move tools to agents * Move agents to dedicate place * Remove subclassing BaseAgent from BaseTool
This commit is contained in:
parent
4256030b4f
commit
8e3a1d193f
68
knowledgehub/agents/base.py
Normal file
68
knowledgehub/agents/base.py
Normal file
|
@ -0,0 +1,68 @@
|
|||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from theflow import Node, Param
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms.chats.base import ChatLLM
|
||||
from kotaemon.llms.completions.base import LLM
|
||||
|
||||
from .tools import BaseTool
|
||||
|
||||
BaseLLM = Union[ChatLLM, LLM]
|
||||
|
||||
|
||||
class AgentType(Enum):
|
||||
"""
|
||||
Enumerated type for agent types.
|
||||
"""
|
||||
|
||||
openai = "openai"
|
||||
openai_multi = "openai_multi"
|
||||
openai_tool = "openai_tool"
|
||||
self_ask = "self_ask"
|
||||
react = "react"
|
||||
rewoo = "rewoo"
|
||||
vanilla = "vanilla"
|
||||
|
||||
@staticmethod
|
||||
def get_agent_class(_type: "AgentType"):
|
||||
"""
|
||||
Get agent class from agent type.
|
||||
:param _type: agent type
|
||||
:return: agent class
|
||||
"""
|
||||
if _type == AgentType.rewoo:
|
||||
from .rewoo.agent import RewooAgent
|
||||
|
||||
return RewooAgent
|
||||
else:
|
||||
raise ValueError(f"Unknown agent type: {_type}")
|
||||
|
||||
|
||||
class BaseAgent(BaseComponent):
|
||||
"""Define base agent interface"""
|
||||
|
||||
name: str = Param(help="Name of the agent.")
|
||||
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
|
||||
description: str = Param(
|
||||
help="Description used to tell the model how/when/why to use the agent. "
|
||||
"You can provide few-shot examples as a part of the description. This will be "
|
||||
"input to the prompt of LLM."
|
||||
)
|
||||
llm: Union[BaseLLM, dict[str, BaseLLM]] = Node(
|
||||
help="Specify LLM to be used in the model, cam be a dict to supply different "
|
||||
"LLMs to multiple purposes in the agent"
|
||||
)
|
||||
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
|
||||
help="A prompt template or a dict to supply different prompt to the agent"
|
||||
)
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [],
|
||||
help="List of plugins / tools to be used in the agent",
|
||||
)
|
||||
|
||||
def add_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Helper method to add tools and update agent state if needed"""
|
||||
self.plugins.extend(tools)
|
|
@ -1,14 +1,13 @@
|
|||
from typing import List, Optional, Type
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.agents import AgentType as LCAgentType
|
||||
from langchain.agents import initialize_agent
|
||||
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.llms.chats.base import ChatLLM
|
||||
from kotaemon.llms.completions.base import LLM
|
||||
from kotaemon.pipelines.tools import BaseTool
|
||||
|
||||
from .base import AgentType, BaseAgent
|
||||
|
||||
|
@ -19,9 +18,6 @@ class LangchainAgent(BaseAgent):
|
|||
name: str = "LangchainAgent"
|
||||
agent_type: AgentType
|
||||
description: str = "LangchainAgent for answering multi-step reasoning questions"
|
||||
args_schema: Optional[Type[BaseModel]] = create_model(
|
||||
"LangchainArgsSchema", instruction=(str, ...)
|
||||
)
|
||||
AGENT_TYPE_MAP = {
|
||||
AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
|
||||
AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
|
||||
|
@ -69,7 +65,7 @@ class LangchainAgent(BaseAgent):
|
|||
self.update_agent_tools()
|
||||
return
|
||||
|
||||
def _run_tool(self, instruction: str) -> Document:
|
||||
def run(self, instruction: str) -> Document:
|
||||
assert (
|
||||
self.agent is not None
|
||||
), "Lanchain AgentExecutor is not correclty initialized"
|
|
@ -1,8 +1,8 @@
|
|||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
from theflow import Param
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
@ -22,15 +22,18 @@ class ReactAgent(BaseAgent):
|
|||
name: str = "ReactAgent"
|
||||
agent_type: AgentType = AgentType.react
|
||||
description: str = "ReactAgent for answering multi-step reasoning questions"
|
||||
llm: Union[BaseLLM, Dict[str, BaseLLM]]
|
||||
llm: BaseLLM | dict[str, BaseLLM]
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
plugins: List[BaseTool] = list()
|
||||
examples: Dict[str, Union[str, List[str]]] = dict()
|
||||
args_schema: Optional[Type[BaseModel]] = create_model(
|
||||
"ReactArgsSchema", instruction=(str, ...)
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [], help="List of tools to be used in the agent. "
|
||||
)
|
||||
examples: dict[str, str | list[str]] = Param(
|
||||
default_callback=lambda _: {}, help="Examples to be used in the agent. "
|
||||
)
|
||||
intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = Param(
|
||||
default_callback=lambda _: [],
|
||||
help="List of AgentAction and observation (tool) output",
|
||||
)
|
||||
intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
|
||||
"""List of AgentAction and observation (tool) output"""
|
||||
max_iterations = 10
|
||||
strict_decode: bool = False
|
||||
|
||||
|
@ -51,7 +54,7 @@ class ReactAgent(BaseAgent):
|
|||
return prompt
|
||||
|
||||
def _construct_scratchpad(
|
||||
self, intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
|
||||
self, intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = []
|
||||
) -> str:
|
||||
"""Construct the scratchpad that lets the agent continue its thought process."""
|
||||
thoughts = ""
|
||||
|
@ -60,7 +63,7 @@ class ReactAgent(BaseAgent):
|
|||
thoughts += f"\nObservation: {observation}\nThought:"
|
||||
return thoughts
|
||||
|
||||
def _parse_output(self, text: str) -> Optional[Union[AgentAction, AgentFinish]]:
|
||||
def _parse_output(self, text: str) -> Optional[AgentAction | AgentFinish]:
|
||||
"""
|
||||
Parse text output from LLM for the next Action or Final Answer
|
||||
Using Regex to parse "Action:\n Action Input:\n" for the next Action
|
||||
|
@ -74,7 +77,7 @@ class ReactAgent(BaseAgent):
|
|||
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
|
||||
)
|
||||
action_match = re.search(regex, text, re.DOTALL)
|
||||
action_output: Optional[Union[AgentAction, AgentFinish]] = None
|
||||
action_output: Optional[AgentAction | AgentFinish] = None
|
||||
if action_match:
|
||||
if includes_answer:
|
||||
raise Exception(
|
||||
|
@ -120,7 +123,7 @@ class ReactAgent(BaseAgent):
|
|||
tool_names=tool_names,
|
||||
)
|
||||
|
||||
def _format_function_map(self) -> Dict[str, BaseTool]:
|
||||
def _format_function_map(self) -> dict[str, BaseTool]:
|
||||
"""Format the function map for the open AI function API.
|
||||
|
||||
Return:
|
|
@ -1,9 +1,9 @@
|
|||
import logging
|
||||
import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
from theflow import Param
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
|
||||
|
@ -23,16 +23,16 @@ class RewooAgent(BaseAgent):
|
|||
name: str = "RewooAgent"
|
||||
agent_type: AgentType = AgentType.rewoo
|
||||
description: str = "RewooAgent for answering multi-step reasoning questions"
|
||||
llm: Union[BaseLLM, Dict[str, BaseLLM]] # {"Planner": xxx, "Solver": xxx}
|
||||
prompt_template: Dict[
|
||||
str, PromptTemplate
|
||||
] = dict() # {"Planner": xxx, "Solver": xxx}
|
||||
plugins: List[BaseTool] = list()
|
||||
examples: Dict[
|
||||
str, Union[str, List[str]]
|
||||
] = dict() # {"Planner": xxx, "Solver": xxx}
|
||||
args_schema: Optional[Type[BaseModel]] = create_model(
|
||||
"RewooArgsSchema", instruction=(str, ...)
|
||||
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx}
|
||||
prompt_template: dict[str, PromptTemplate] = Param(
|
||||
default_callback=lambda _: {},
|
||||
help="A dict to supply different prompt to the agent.",
|
||||
)
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [], help="A list of plugins to be used in the model."
|
||||
)
|
||||
examples: dict[str, str | list[str]] = Param(
|
||||
default_callback=lambda _: {}, help="Examples to be used in the agent."
|
||||
)
|
||||
|
||||
def _get_llms(self):
|
||||
|
@ -49,7 +49,7 @@ class RewooAgent(BaseAgent):
|
|||
|
||||
def _parse_plan_map(
|
||||
self, planner_response: str
|
||||
) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
|
||||
) -> tuple[dict[str, list[str]], dict[str, str]]:
|
||||
"""
|
||||
Parse planner output. It should be an n-to-n mapping from Plans to #Es.
|
||||
This is because sometimes LLM cannot follow the strict output format.
|
||||
|
@ -66,7 +66,7 @@ class RewooAgent(BaseAgent):
|
|||
This function should also return a plan map.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, List[str]], Dict[str, str]]: A list of plan map
|
||||
tuple[Dict[str, List[str]], Dict[str, str]]: A list of plan map
|
||||
"""
|
||||
valid_chunk = [
|
||||
line
|
||||
|
@ -74,8 +74,8 @@ class RewooAgent(BaseAgent):
|
|||
if line.startswith("#Plan") or line.startswith("#E")
|
||||
]
|
||||
|
||||
plan_to_es: Dict[str, List[str]] = dict()
|
||||
plans: Dict[str, str] = dict()
|
||||
plan_to_es: dict[str, list[str]] = dict()
|
||||
plans: dict[str, str] = dict()
|
||||
for line in valid_chunk:
|
||||
if line.startswith("#Plan"):
|
||||
plan = line.split(":", 1)[0].strip()
|
||||
|
@ -88,7 +88,7 @@ class RewooAgent(BaseAgent):
|
|||
|
||||
def _parse_planner_evidences(
|
||||
self, planner_response: str
|
||||
) -> Tuple[Dict[str, str], List[List[str]]]:
|
||||
) -> tuple[dict[str, str], list[list[str]]]:
|
||||
"""
|
||||
Parse planner output. This should return a mapping from #E to tool call.
|
||||
It should also identify the level of each #E in dependency map.
|
||||
|
@ -99,11 +99,11 @@ class RewooAgent(BaseAgent):
|
|||
}, [[#E1, #E2], [#E3, #E4]]
|
||||
|
||||
Returns:
|
||||
Tuple[dict[str, str], List[List[str]]]:
|
||||
tuple[dict[str, str], List[List[str]]]:
|
||||
A mapping from #E to tool call and a list of levels.
|
||||
"""
|
||||
evidences: Dict[str, str] = dict()
|
||||
dependence: Dict[str, List[str]] = dict()
|
||||
evidences: dict[str, str] = dict()
|
||||
dependence: dict[str, list[str]] = dict()
|
||||
for line in planner_response.splitlines():
|
||||
if line.startswith("#E") and line[2].isdigit():
|
||||
e, tool_call = line.split(":", 1)
|
||||
|
@ -134,8 +134,8 @@ class RewooAgent(BaseAgent):
|
|||
def _run_plugin(
|
||||
self,
|
||||
e: str,
|
||||
planner_evidences: Dict[str, str],
|
||||
worker_evidences: Dict[str, str],
|
||||
planner_evidences: dict[str, str],
|
||||
worker_evidences: dict[str, str],
|
||||
output=BaseScratchPad(),
|
||||
):
|
||||
"""
|
||||
|
@ -169,8 +169,8 @@ class RewooAgent(BaseAgent):
|
|||
|
||||
def _get_worker_evidence(
|
||||
self,
|
||||
planner_evidences: Dict[str, str],
|
||||
evidences_level: List[List[str]],
|
||||
planner_evidences: dict[str, str],
|
||||
evidences_level: list[list[str]],
|
||||
output=BaseScratchPad(),
|
||||
) -> Any:
|
||||
"""
|
||||
|
@ -185,7 +185,7 @@ class RewooAgent(BaseAgent):
|
|||
Returns:
|
||||
A mapping from #E to tool call.
|
||||
"""
|
||||
worker_evidences: Dict[str, str] = dict()
|
||||
worker_evidences: dict[str, str] = dict()
|
||||
plugin_cost, plugin_token = 0.0, 0.0
|
||||
with ThreadPoolExecutor() as pool:
|
||||
for level in evidences_level:
|
||||
|
@ -218,7 +218,7 @@ class RewooAgent(BaseAgent):
|
|||
if p.name == name:
|
||||
return p
|
||||
|
||||
def _run_tool(self, instruction: str, use_citation: bool = False) -> Document:
|
||||
def run(self, instruction: str, use_citation: bool = False) -> Document:
|
||||
"""
|
||||
Run the agent with a given instruction.
|
||||
"""
|
|
@ -1,7 +1,8 @@
|
|||
from typing import Any, List, Optional, Union
|
||||
|
||||
from ....base import BaseComponent
|
||||
from ....llms import PromptTemplate
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
from ..base import BaseLLM, BaseTool
|
||||
from ..output.base import BaseScratchPad
|
||||
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
|
|
@ -1,4 +1,4 @@
|
|||
from ...base import Document
|
||||
from kotaemon.base import Document
|
||||
|
||||
|
||||
def get_plugin_response_content(output) -> str:
|
|
@ -1,61 +0,0 @@
|
|||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms.chats.base import ChatLLM
|
||||
from kotaemon.llms.completions.base import LLM
|
||||
from kotaemon.pipelines.tools import BaseTool
|
||||
|
||||
BaseLLM = Union[ChatLLM, LLM]
|
||||
|
||||
|
||||
class AgentType(Enum):
|
||||
"""
|
||||
Enumerated type for agent types.
|
||||
"""
|
||||
|
||||
openai = "openai"
|
||||
openai_multi = "openai_multi"
|
||||
openai_tool = "openai_tool"
|
||||
self_ask = "self_ask"
|
||||
react = "react"
|
||||
rewoo = "rewoo"
|
||||
vanilla = "vanilla"
|
||||
|
||||
@staticmethod
|
||||
def get_agent_class(_type: "AgentType"):
|
||||
"""
|
||||
Get agent class from agent type.
|
||||
:param _type: agent type
|
||||
:return: agent class
|
||||
"""
|
||||
if _type == AgentType.rewoo:
|
||||
from .rewoo.agent import RewooAgent
|
||||
|
||||
return RewooAgent
|
||||
else:
|
||||
raise ValueError(f"Unknown agent type: {_type}")
|
||||
|
||||
|
||||
class BaseAgent(BaseTool):
|
||||
name: str
|
||||
"""Name of the agent."""
|
||||
agent_type: AgentType
|
||||
"""Agent type, must be one of AgentType"""
|
||||
description: str
|
||||
"""Description used to tell the model how/when/why to use the agent.
|
||||
You can provide few-shot examples as a part of the description. This will be
|
||||
input to the prompt of LLM."""
|
||||
llm: Union[BaseLLM, Dict[str, BaseLLM]]
|
||||
"""Specify LLM to be used in the model, cam be a dict to supply different
|
||||
LLMs to multiple purposes in the agent"""
|
||||
prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]]
|
||||
"""A prompt template or a dict to supply different prompt to the agent
|
||||
"""
|
||||
plugins: List[BaseTool] = []
|
||||
"""List of plugins / tools to be used in the agent
|
||||
"""
|
||||
|
||||
def add_tools(self, tools: List[BaseTool]) -> None:
|
||||
"""Helper method to add tools and update agent state if needed"""
|
||||
self.plugins.extend(tools)
|
|
@ -8,6 +8,7 @@ from llama_index.readers.base import BaseReader
|
|||
from theflow import Node
|
||||
from theflow.utils.modules import ObjectInitDeclaration as _
|
||||
|
||||
from kotaemon.agents import BaseAgent
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.indices.extractors import BaseDocParser
|
||||
|
@ -20,7 +21,6 @@ from kotaemon.loaders import (
|
|||
OCRReader,
|
||||
PandasExcelReader,
|
||||
)
|
||||
from kotaemon.pipelines.agents import BaseAgent
|
||||
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
|
||||
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||
from kotaemon.storages import (
|
||||
|
|
|
@ -5,16 +5,16 @@ from typing import List, Sequence
|
|||
from theflow import Node
|
||||
from theflow.utils.modules import ObjectInitDeclaration as _
|
||||
|
||||
from kotaemon.agents import BaseAgent
|
||||
from kotaemon.agents.tools import ComponentTool
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.base.schema import Document, RetrievedDocument
|
||||
from kotaemon.embeddings import AzureOpenAIEmbeddings
|
||||
from kotaemon.indices.rankings import BaseReranking
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
from kotaemon.pipelines.agents import BaseAgent
|
||||
from kotaemon.pipelines.citation import CitationPipeline
|
||||
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||
from kotaemon.pipelines.tools import ComponentTool
|
||||
from kotaemon.storages import (
|
||||
BaseDocumentStore,
|
||||
BaseVectorStore,
|
||||
|
|
|
@ -3,17 +3,12 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from kotaemon.agents.base import AgentType
|
||||
from kotaemon.agents.langchain import LangchainAgent
|
||||
from kotaemon.agents.react import ReactAgent
|
||||
from kotaemon.agents.rewoo import RewooAgent
|
||||
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
from kotaemon.pipelines.agents.base import AgentType
|
||||
from kotaemon.pipelines.agents.langchain import LangchainAgent
|
||||
from kotaemon.pipelines.agents.react import ReactAgent
|
||||
from kotaemon.pipelines.agents.rewoo import RewooAgent
|
||||
from kotaemon.pipelines.tools import (
|
||||
BaseTool,
|
||||
GoogleSearchTool,
|
||||
LLMTool,
|
||||
WikipediaTool,
|
||||
)
|
||||
|
||||
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
|
||||
|
||||
|
|
|
@ -4,11 +4,11 @@ from pathlib import Path
|
|||
import pytest
|
||||
from openai.resources.embeddings import Embeddings
|
||||
|
||||
from kotaemon.agents.tools import ComponentTool, GoogleSearchTool, WikipediaTool
|
||||
from kotaemon.base import Document
|
||||
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
|
||||
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
|
||||
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||
from kotaemon.pipelines.tools import ComponentTool, GoogleSearchTool, WikipediaTool
|
||||
from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
|
||||
|
||||
with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
|
||||
|
|
Loading…
Reference in New Issue
Block a user