refractor agents (#100)
* refractor agents * minor cosmetic, add terminal ui for cli * pump to 0.3.4 * Add temporary path * fix unclose files in tests --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
@@ -3,15 +3,15 @@ import re
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from theflow import Param
|
||||
from theflow import Node, Param
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.agents.base import BaseAgent
|
||||
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.agents.utils import get_plugin_response_content
|
||||
from kotaemon.indices.qa import CitationPipeline
|
||||
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
|
||||
from ..output.base import BaseScratchPad
|
||||
from ..utils import get_plugin_response_content
|
||||
from .planner import Planner
|
||||
from .solver import Solver
|
||||
|
||||
@@ -23,7 +23,8 @@ class RewooAgent(BaseAgent):
|
||||
name: str = "RewooAgent"
|
||||
agent_type: AgentType = AgentType.rewoo
|
||||
description: str = "RewooAgent for answering multi-step reasoning questions"
|
||||
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx}
|
||||
planner_llm: BaseLLM
|
||||
solver_llm: BaseLLM
|
||||
prompt_template: dict[str, PromptTemplate] = Param(
|
||||
default_callback=lambda _: {},
|
||||
help="A dict to supply different prompt to the agent.",
|
||||
@@ -35,17 +36,22 @@ class RewooAgent(BaseAgent):
|
||||
default_callback=lambda _: {}, help="Examples to be used in the agent."
|
||||
)
|
||||
|
||||
def _get_llms(self):
|
||||
if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM):
|
||||
return {"Planner": self.llm, "Solver": self.llm}
|
||||
elif (
|
||||
isinstance(self.llm, dict)
|
||||
and "Planner" in self.llm
|
||||
and "Solver" in self.llm
|
||||
):
|
||||
return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
|
||||
else:
|
||||
raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")
|
||||
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
|
||||
def planner(self):
|
||||
return Planner(
|
||||
model=self.planner_llm,
|
||||
plugins=self.plugins,
|
||||
prompt_template=self.prompt_template.get("Planner", None),
|
||||
examples=self.examples.get("Planner", None),
|
||||
)
|
||||
|
||||
@Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
|
||||
def solver(self):
|
||||
return Solver(
|
||||
model=self.solver_llm,
|
||||
prompt_template=self.prompt_template.get("Solver", None),
|
||||
examples=self.examples.get("Solver", None),
|
||||
)
|
||||
|
||||
def _parse_plan_map(
|
||||
self, planner_response: str
|
||||
@@ -76,13 +82,16 @@ class RewooAgent(BaseAgent):
|
||||
|
||||
plan_to_es: dict[str, list[str]] = dict()
|
||||
plans: dict[str, str] = dict()
|
||||
prev_key = ""
|
||||
for line in valid_chunk:
|
||||
if line.startswith("#Plan"):
|
||||
plan = line.split(":", 1)[0].strip()
|
||||
plans[plan] = line.split(":", 1)[1].strip()
|
||||
plan_to_es[plan] = []
|
||||
elif line.startswith("#E"):
|
||||
plan_to_es[plan].append(line.split(":", 1)[0].strip())
|
||||
key, description = line.split(":", 1)
|
||||
key = key.strip()
|
||||
if key.startswith("#Plan"):
|
||||
plans[key] = description.strip()
|
||||
plan_to_es[key] = []
|
||||
prev_key = key
|
||||
elif key.startswith("#E"):
|
||||
plan_to_es[prev_key].append(key)
|
||||
|
||||
return plan_to_es, plans
|
||||
|
||||
@@ -218,7 +227,8 @@ class RewooAgent(BaseAgent):
|
||||
if p.name == name:
|
||||
return p
|
||||
|
||||
def run(self, instruction: str, use_citation: bool = False) -> Document:
|
||||
@BaseAgent.safeguard_run
|
||||
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
|
||||
"""
|
||||
Run the agent with a given instruction.
|
||||
"""
|
||||
@@ -226,27 +236,12 @@ class RewooAgent(BaseAgent):
|
||||
total_cost = 0.0
|
||||
total_token = 0
|
||||
|
||||
planner_llm = self._get_llms()["Planner"]
|
||||
solver_llm = self._get_llms()["Solver"]
|
||||
|
||||
planner = Planner(
|
||||
model=planner_llm,
|
||||
plugins=self.plugins,
|
||||
prompt_template=self.prompt_template.get("Planner", None),
|
||||
examples=self.examples.get("Planner", None),
|
||||
)
|
||||
solver = Solver(
|
||||
model=solver_llm,
|
||||
prompt_template=self.prompt_template.get("Solver", None),
|
||||
examples=self.examples.get("Solver", None),
|
||||
)
|
||||
|
||||
# Plan
|
||||
planner_output = planner(instruction)
|
||||
plannner_text_output = planner_output.text
|
||||
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
|
||||
planner_output = self.planner(instruction)
|
||||
planner_text_output = planner_output.text
|
||||
plan_to_es, plans = self._parse_plan_map(planner_text_output)
|
||||
planner_evidences, evidence_level = self._parse_planner_evidences(
|
||||
plannner_text_output
|
||||
planner_text_output
|
||||
)
|
||||
|
||||
# Work
|
||||
@@ -260,20 +255,19 @@ class RewooAgent(BaseAgent):
|
||||
worker_log += f"{e}: {worker_evidences[e]}\n"
|
||||
|
||||
# Solve
|
||||
solver_output = solver(instruction, worker_log)
|
||||
solver_output = self.solver(instruction, worker_log)
|
||||
solver_output_text = solver_output.text
|
||||
if use_citation:
|
||||
citation_pipeline = CitationPipeline(llm=solver_llm)
|
||||
citation_pipeline = CitationPipeline(llm=self.solver_llm)
|
||||
citation = citation_pipeline(context=worker_log, question=instruction)
|
||||
else:
|
||||
citation = None
|
||||
|
||||
return Document(
|
||||
return AgentOutput(
|
||||
text=solver_output_text,
|
||||
metadata={
|
||||
"agent": "react",
|
||||
"cost": total_cost,
|
||||
"usage": total_token,
|
||||
"citation": citation,
|
||||
},
|
||||
agent_type=self.agent_type,
|
||||
status="finished",
|
||||
total_tokens=total_token,
|
||||
total_cost=total_cost,
|
||||
citation=citation,
|
||||
)
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from kotaemon.agents.base import BaseLLM, BaseTool
|
||||
from kotaemon.agents.io import BaseScratchPad
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
from ..base import BaseLLM, BaseTool
|
||||
from ..output.base import BaseScratchPad
|
||||
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
|
||||
|
||||
|
||||
|
@@ -1,10 +1,9 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from kotaemon.agents.io import BaseScratchPad
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from ..base import BaseLLM
|
||||
from ..output.base import BaseScratchPad
|
||||
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user