refractor agents (#100)

* refractor agents

* minor cosmetic, add terminal ui for cli

* pump to 0.3.4

* Add temporary path

* fix unclose files in tests

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2023-12-06 17:06:29 +07:00
committed by GitHub
parent d9e925eb75
commit 797df5a69c
21 changed files with 281 additions and 228 deletions

View File

@@ -4,12 +4,11 @@ from typing import Optional
from theflow import Param
from kotaemon.base.schema import Document
from kotaemon.agents.base import BaseAgent, BaseLLM
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
from kotaemon.agents.tools import BaseTool
from kotaemon.llms import PromptTemplate
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import AgentAction, AgentFinish
FINAL_ANSWER_ACTION = "Final Answer:"
@@ -22,7 +21,7 @@ class ReactAgent(BaseAgent):
name: str = "ReactAgent"
agent_type: AgentType = AgentType.react
description: str = "ReactAgent for answering multi-step reasoning questions"
llm: BaseLLM | dict[str, BaseLLM]
llm: BaseLLM
prompt_template: Optional[PromptTemplate] = None
plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="List of tools to be used in the agent. "
@@ -34,7 +33,7 @@ class ReactAgent(BaseAgent):
default_callback=lambda _: [],
help="List of AgentAction and observation (tool) output",
)
max_iterations = 10
max_iterations: int = 10
strict_decode: bool = False
def _compose_plugin_description(self) -> str:
@@ -141,7 +140,7 @@ class ReactAgent(BaseAgent):
"""
self.intermediate_steps = []
def run(self, instruction, max_iterations=None):
def run(self, instruction, max_iterations=None) -> AgentOutput:
"""
Run the agent with the given instruction.
@@ -161,11 +160,15 @@ class ReactAgent(BaseAgent):
logging.info(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
status = "failed"
response_text = None
for _ in range(max_iterations):
for step_count in range(1, max_iterations + 1):
prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}")
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
response = self.llm(
prompt, stop=["Observation:"]
) # could cause bugs if llm doesn't have `stop` as a parameter
response_text = response.text
logging.info(f"Response: {response_text}")
action_step = self._parse_output(response_text)
@@ -185,13 +188,18 @@ class ReactAgent(BaseAgent):
self.intermediate_steps.append((action_step, result))
if is_finished_chain:
logging.info(f"Finished after {step_count} steps.")
status = "finished"
break
else:
status = "stopped"
return Document(
return AgentOutput(
text=response_text,
metadata={
"agent": "react",
"cost": total_cost,
"usage": total_token,
},
agent_type=self.agent_type,
status=status,
total_tokens=total_token,
total_cost=total_cost,
intermediate_steps=self.intermediate_steps,
max_iterations=max_iterations,
)