refractor agents (#100)

* refractor agents

* minor cosmetic, add terminal ui for cli

* pump to 0.3.4

* Add temporary path

* fix unclose files in tests

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin
2023-12-06 17:06:29 +07:00
committed by GitHub
parent d9e925eb75
commit 797df5a69c
21 changed files with 281 additions and 228 deletions

View File

@@ -3,15 +3,15 @@ import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from theflow import Param
from theflow import Node, Param
from kotaemon.base.schema import Document
from kotaemon.agents.base import BaseAgent
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
from kotaemon.agents.tools import BaseTool
from kotaemon.agents.utils import get_plugin_response_content
from kotaemon.indices.qa import CitationPipeline
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
from kotaemon.llms import BaseLLM, PromptTemplate
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from ..utils import get_plugin_response_content
from .planner import Planner
from .solver import Solver
@@ -23,7 +23,8 @@ class RewooAgent(BaseAgent):
name: str = "RewooAgent"
agent_type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions"
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx}
planner_llm: BaseLLM
solver_llm: BaseLLM
prompt_template: dict[str, PromptTemplate] = Param(
default_callback=lambda _: {},
help="A dict to supply different prompt to the agent.",
@@ -35,17 +36,22 @@ class RewooAgent(BaseAgent):
default_callback=lambda _: {}, help="Examples to be used in the agent."
)
def _get_llms(self):
if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM):
return {"Planner": self.llm, "Solver": self.llm}
elif (
isinstance(self.llm, dict)
and "Planner" in self.llm
and "Solver" in self.llm
):
return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
else:
raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
def planner(self):
return Planner(
model=self.planner_llm,
plugins=self.plugins,
prompt_template=self.prompt_template.get("Planner", None),
examples=self.examples.get("Planner", None),
)
@Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
def solver(self):
return Solver(
model=self.solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
def _parse_plan_map(
self, planner_response: str
@@ -76,13 +82,16 @@ class RewooAgent(BaseAgent):
plan_to_es: dict[str, list[str]] = dict()
plans: dict[str, str] = dict()
prev_key = ""
for line in valid_chunk:
if line.startswith("#Plan"):
plan = line.split(":", 1)[0].strip()
plans[plan] = line.split(":", 1)[1].strip()
plan_to_es[plan] = []
elif line.startswith("#E"):
plan_to_es[plan].append(line.split(":", 1)[0].strip())
key, description = line.split(":", 1)
key = key.strip()
if key.startswith("#Plan"):
plans[key] = description.strip()
plan_to_es[key] = []
prev_key = key
elif key.startswith("#E"):
plan_to_es[prev_key].append(key)
return plan_to_es, plans
@@ -218,7 +227,8 @@ class RewooAgent(BaseAgent):
if p.name == name:
return p
def run(self, instruction: str, use_citation: bool = False) -> Document:
@BaseAgent.safeguard_run
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
"""
Run the agent with a given instruction.
"""
@@ -226,27 +236,12 @@ class RewooAgent(BaseAgent):
total_cost = 0.0
total_token = 0
planner_llm = self._get_llms()["Planner"]
solver_llm = self._get_llms()["Solver"]
planner = Planner(
model=planner_llm,
plugins=self.plugins,
prompt_template=self.prompt_template.get("Planner", None),
examples=self.examples.get("Planner", None),
)
solver = Solver(
model=solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
# Plan
planner_output = planner(instruction)
plannner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
planner_output = self.planner(instruction)
planner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(planner_text_output)
planner_evidences, evidence_level = self._parse_planner_evidences(
plannner_text_output
planner_text_output
)
# Work
@@ -260,20 +255,19 @@ class RewooAgent(BaseAgent):
worker_log += f"{e}: {worker_evidences[e]}\n"
# Solve
solver_output = solver(instruction, worker_log)
solver_output = self.solver(instruction, worker_log)
solver_output_text = solver_output.text
if use_citation:
citation_pipeline = CitationPipeline(llm=solver_llm)
citation_pipeline = CitationPipeline(llm=self.solver_llm)
citation = citation_pipeline(context=worker_log, question=instruction)
else:
citation = None
return Document(
return AgentOutput(
text=solver_output_text,
metadata={
"agent": "react",
"cost": total_cost,
"usage": total_token,
"citation": citation,
},
agent_type=self.agent_type,
status="finished",
total_tokens=total_token,
total_cost=total_cost,
citation=citation,
)

View File

@@ -1,10 +1,10 @@
from typing import Any, List, Optional, Union
from kotaemon.agents.base import BaseLLM, BaseTool
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from ..base import BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt

View File

@@ -1,10 +1,9 @@
from typing import Any, List, Optional, Union
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from kotaemon.llms import BaseLLM, PromptTemplate
from ..base import BaseLLM
from ..output.base import BaseScratchPad
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt