* refactor agents * minor cosmetic changes, add terminal UI for CLI * bump to 0.3.4 * Add temporary path * fix unclosed files in tests --------- Co-authored-by: trducng <trungduc1992@gmail.com>
82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
from typing import List, Optional
|
|
|
|
from langchain.agents import AgentType as LCAgentType
|
|
from langchain.agents import initialize_agent
|
|
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
|
|
|
|
from kotaemon.llms import LLM, ChatLLM
|
|
|
|
from .base import BaseAgent
|
|
from .io import AgentOutput, AgentType
|
|
from .tools import BaseTool
|
|
|
|
|
|
class LangchainAgent(BaseAgent):
    """Wrapper for Langchain Agent

    Adapts a kotaemon agent configuration (``llm``, ``plugins``,
    ``agent_type``) to a Langchain ``AgentExecutor`` created with
    ``langchain.agents.initialize_agent``. Only the agent types listed in
    ``AGENT_TYPE_MAP`` are supported.
    """

    name: str = "LangchainAgent"
    agent_type: AgentType
    description: str = "LangchainAgent for answering multi-step reasoning questions"
    # kotaemon agent type -> equivalent Langchain agent type
    AGENT_TYPE_MAP = {
        AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
        AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
        AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION,
        AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH,
    }
    # Langchain executor, (re)built by ``update_agent_tools``. It stays at its
    # previous value (``None`` initially) when no executor can be built, i.e.
    # for a self-ask agent that has no ``search_doc`` tool.
    agent: Optional[LCAgentExecutor] = None

    def __init__(self, *args, **kwargs):
        """Initialize the component and build the Langchain executor.

        Raises:
            NotImplementedError: if ``agent_type`` has no Langchain
                equivalent in ``AGENT_TYPE_MAP``.
        """
        super().__init__(*args, **kwargs)

        if self.agent_type not in self.AGENT_TYPE_MAP:
            raise NotImplementedError(
                f"AgentType {self.agent_type} not supported by Langchain wrapper"
            )
        self.update_agent_tools()

    def update_agent_tools(self):
        """(Re)build the Langchain ``AgentExecutor`` from ``self.plugins``.

        Every plugin is converted with ``to_langchain_format()``. For the
        self-ask agent only the first tool named ``search_doc`` is kept. It is
        renamed to ``"Intermediate Answer"``, the tool name Langchain's
        self-ask-with-search agent looks for. If there is no ``search_doc``
        tool, the existing executor is left untouched.
        """
        assert isinstance(self.llm, (ChatLLM, LLM))
        langchain_plugins = [tool.to_langchain_format() for tool in self.plugins]

        # a fix for search_doc tool name:
        # use "Intermediate Answer" for self-ask agent
        if self.agent_type == AgentType.self_ask:
            search_tool = next(
                (plugin for plugin in langchain_plugins if plugin.name == "search_doc"),
                None,
            )
            if search_tool is None:
                # no usable search tool: keep the current executor as-is
                return
            search_tool.name = "Intermediate Answer"
            langchain_plugins = [search_tool]

        # reinit Langchain AgentExecutor
        self.agent = initialize_agent(
            langchain_plugins,
            # TODO: could cause bugs for non-langchain llms
            # related to https://github.com/Cinnamon/kotaemon/issues/73
            # NOTE(review): presumably ``_obj`` is the wrapped Langchain LLM
            # object; confirm for every LLM/ChatLLM subclass.
            self.llm._obj,  # type: ignore
            agent=self.AGENT_TYPE_MAP[self.agent_type],
            handle_parsing_errors=True,
            verbose=True,
        )

    def add_tools(self, tools: List[BaseTool]) -> None:
        """Register extra tools and rebuild the Langchain executor."""
        super().add_tools(tools)
        self.update_agent_tools()

    def run(self, instruction: str) -> AgentOutput:
        """Run the Langchain executor on ``instruction``.

        Args:
            instruction: the question/instruction passed to the executor.

        Returns:
            An ``AgentOutput`` holding the executor's ``"output"`` text, this
            agent's ``agent_type`` and status ``"finished"``.
        """
        assert (
            self.agent is not None
        ), "Langchain AgentExecutor is not correctly initialized"

        # Langchain AgentExecutor call
        output = self.agent(instruction)["output"]

        return AgentOutput(
            text=output,
            agent_type=self.agent_type,
            status="finished",
        )
|