Add Langchain Agent wrapper with OpenAI Function / Self-ask agent support (#82)
* update Param() type hint in MVP * update default embedding endpoint * update Langchain agent wrapper * update langchain agent
This commit is contained in:
committed by
GitHub
parent
0a3fc4b228
commit
8bb7ad91e0
85
knowledgehub/pipelines/agents/langchain.py
Normal file
85
knowledgehub/pipelines/agents/langchain.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from langchain.agents import AgentType as LCAgentType
|
||||
from langchain.agents import initialize_agent
|
||||
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.llms.chats.base import ChatLLM
|
||||
from kotaemon.llms.completions.base import LLM
|
||||
from kotaemon.pipelines.tools import BaseTool
|
||||
|
||||
from .base import AgentType, BaseAgent
|
||||
|
||||
|
||||
class LangchainAgent(BaseAgent):
|
||||
"""Wrapper for Langchain Agent"""
|
||||
|
||||
name: str = "LangchainAgent"
|
||||
agent_type: AgentType
|
||||
description: str = "LangchainAgent for answering multi-step reasoning questions"
|
||||
args_schema: Optional[Type[BaseModel]] = create_model(
|
||||
"LangchainArgsSchema", instruction=(str, ...)
|
||||
)
|
||||
AGENT_TYPE_MAP = {
|
||||
AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
|
||||
AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
|
||||
AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH,
|
||||
}
|
||||
agent: Optional[LCAgentExecutor] = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
if self.agent_type not in self.AGENT_TYPE_MAP:
|
||||
raise NotImplementedError(
|
||||
f"AgentType {self.agent_type } not supported by Langchain wrapper"
|
||||
)
|
||||
self.update_agent_tools()
|
||||
|
||||
def update_agent_tools(self):
|
||||
assert isinstance(self.llm, (ChatLLM, LLM))
|
||||
langchain_plugins = [tool.to_langchain_format() for tool in self.plugins]
|
||||
|
||||
# a fix for search_doc tool name:
|
||||
# use "Intermediate Answer" for self-ask agent
|
||||
found_search_tool = False
|
||||
if self.agent_type == AgentType.self_ask:
|
||||
for plugin in langchain_plugins:
|
||||
if plugin.name == "search_doc":
|
||||
plugin.name = "Intermediate Answer"
|
||||
langchain_plugins = [plugin]
|
||||
found_search_tool = True
|
||||
break
|
||||
|
||||
if self.agent_type != AgentType.self_ask or found_search_tool:
|
||||
# reinit Langchain AgentExecutor
|
||||
self.agent = initialize_agent(
|
||||
langchain_plugins,
|
||||
self.llm.agent,
|
||||
agent=self.AGENT_TYPE_MAP[self.agent_type],
|
||||
handle_parsing_errors=True,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
def add_tools(self, tools: List[BaseTool]) -> None:
|
||||
super().add_tools(tools)
|
||||
self.update_agent_tools()
|
||||
return
|
||||
|
||||
def _run_tool(self, instruction: str) -> Document:
|
||||
assert (
|
||||
self.agent is not None
|
||||
), "Lanchain AgentExecutor is not correclty initialized"
|
||||
# Langchain AgentExecutor call
|
||||
output = self.agent(instruction)["output"]
|
||||
return Document(
|
||||
text=output,
|
||||
metadata={
|
||||
"agent": "langchain",
|
||||
"cost": 0.0,
|
||||
"usage": 0,
|
||||
},
|
||||
)
|
Reference in New Issue
Block a user