Feat/Add ReAct and ReWOO Reasoning Pipelines (#43)
* Add ReactAgentPipeline by wrapping the ReactAgent
* Implement stream processing for ReactAgentPipeline and RewooAgentPipeline
* Fix highlight_citation in Rewoo and remove highlight_citation from React
* Fix importing ktem.llms inside kotaemon
* fix: Change Rewoo::solver's output to LLMInterface instead of plain text
* Add more user_settings to the RewooAgentPipeline
* Fix LLMTool
* Add more user_settings to the ReactAgentPipeline
* Minor fix
* Stream the react agent immediately
* Yield the Rewoo progress to info panel
* Hide the agent in flowsettings
* Remove redundant comments

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in parent ec11b54ff2, commit 466adf2d94.
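For reference, a minimal sketch of how the streaming interface added in this commit can be consumed. It assumes `agent` is an already-configured ReactAgent instance and the question string is only an example; each yielded AgentOutput carries the new "thinking" status while the agent reasons and "finished"/"stopped"/"failed" once it stops.

    # Sketch only: `agent` is assumed to be a configured ReactAgent.
    for output in agent.stream("What is the ReWOO paradigm?"):
        if output.status == "thinking":
            # intermediate (AgentAction, observation) pair for the info panel
            print(output.intermediate_steps)
        else:
            # final answer once status is "finished" (or "stopped"/"failed")
            print(output.text)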
@@ -253,5 +253,6 @@ class AgentOutput(LLMInterface):
     text: str
     type: str = "agent"
     agent_type: AgentType
-    status: Literal["finished", "stopped", "failed"]
+    status: Literal["thinking", "finished", "stopped", "failed"]
     error: Optional[str] = None
+    intermediate_steps: Optional[list] = None
@@ -1,11 +1,15 @@
 import logging
 import re
+from functools import partial
 from typing import Optional
 
+import tiktoken
+
 from kotaemon.agents.base import BaseAgent, BaseLLM
 from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
 from kotaemon.agents.tools import BaseTool
-from kotaemon.base import Param
+from kotaemon.base import Document, Param
+from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import PromptTemplate
 
 FINAL_ANSWER_ACTION = "Final Answer:"
@@ -22,6 +26,7 @@ class ReactAgent(BaseAgent):
     description: str = "ReactAgent for answering multi-step reasoning questions"
     llm: BaseLLM
     prompt_template: Optional[PromptTemplate] = None
+    output_lang: str = "English"
     plugins: list[BaseTool] = Param(
         default_callback=lambda _: [], help="List of tools to be used in the agent. "
     )
@@ -32,8 +37,18 @@ class ReactAgent(BaseAgent):
         default_callback=lambda _: [],
         help="List of AgentAction and observation (tool) output",
     )
-    max_iterations: int = 10
+    max_iterations: int = 5
     strict_decode: bool = False
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=800,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )
 
     def _compose_plugin_description(self) -> str:
         """
@@ -119,6 +134,7 @@ class ReactAgent(BaseAgent):
             agent_scratchpad=agent_scratchpad,
             tool_description=tool_description,
             tool_names=tool_names,
+            lang=self.output_lang,
         )
 
     def _format_function_map(self) -> dict[str, BaseTool]:
@@ -133,6 +149,20 @@ class ReactAgent(BaseAgent):
             function_map[plugin.name] = plugin
         return function_map
 
+    def _trim(self, text: str) -> str:
+        """
+        Trim the text to the maximum token length.
+        """
+        if isinstance(text, str):
+            texts = self.trim_func([Document(text=text)])
+        elif isinstance(text, Document):
+            texts = self.trim_func([text])
+        else:
+            raise ValueError("Invalid text type to trim")
+        trim_text = texts[0].text
+        logging.info(f"len (trimmed): {len(trim_text)}")
+        return trim_text
+
     def clear(self):
         """
         Clear and reset the agent.
@@ -183,6 +213,11 @@ class ReactAgent(BaseAgent):
                 logging.info(f"Action: {action_name}")
                 logging.info(f"Tool Input: {tool_input}")
                 result = self._format_function_map()[action_name](tool_input)
+
+                # trim the worker output to 1000 tokens, as we are appending
+                # all workers' logs and it can exceed the token limit if we
+                # don't limit each. Fix this number regarding to the LLM capacity.
+                result = self._trim(result)
                 logging.info(f"Result: {result}")
 
             self.intermediate_steps.append((action_step, result))
@@ -202,3 +237,100 @@ class ReactAgent(BaseAgent):
             intermediate_steps=self.intermediate_steps,
             max_iterations=max_iterations,
         )
+
+    def stream(self, instruction, max_iterations=None):
+        """
+        Stream the agent with the given instruction.
+
+        Args:
+            instruction: Instruction to run the agent with.
+            max_iterations: Maximum number of iterations
+                of reasoning steps, defaults to 10.
+
+        Return:
+            AgentOutput object.
+        """
+        if not max_iterations:
+            max_iterations = self.max_iterations
+        assert max_iterations > 0
+
+        self.clear()
+        logging.info(f"Running {self.name} with instruction: {instruction}")
+        print(f"Running {self.name} with instruction: {instruction}")
+        total_cost = 0.0
+        total_token = 0
+        status = "failed"
+        response_text = None
+
+        for step_count in range(1, max_iterations + 1):
+            prompt = self._compose_prompt(instruction)
+            logging.info(f"Prompt: {prompt}")
+            print(f"Prompt: {prompt}")
+            response = self.llm(
+                prompt, stop=["Observation:"]
+            )  # TODO: could cause bugs if llm doesn't have `stop` as a parameter
+            response_text = response.text
+            logging.info(f"Response: {response_text}")
+            print(f"Response: {response_text}")
+            action_step = self._parse_output(response_text)
+            if action_step is None:
+                raise ValueError("Invalid action")
+            is_finished_chain = isinstance(action_step, AgentFinish)
+            if is_finished_chain:
+                result = response_text
+                if "Final Answer:" in response_text:
+                    result = response_text.split("Final Answer:")[-1].strip()
+            else:
+                assert isinstance(action_step, AgentAction)
+                action_name = action_step.tool
+                tool_input = action_step.tool_input
+                logging.info(f"Action: {action_name}")
+                print(f"Action: {action_name}")
+                logging.info(f"Tool Input: {tool_input}")
+                print(f"Tool Input: {tool_input}")
+                result = self._format_function_map()[action_name](tool_input)
+
+                # trim the worker output to 1000 tokens, as we are appending
+                # all workers' logs and it can exceed the token limit if we
+                # don't limit each. Fix this number regarding to the LLM capacity.
+                result = self._trim(result)
+                logging.info(f"Result: {result}")
+                print(f"Result: {result}")
+
+            self.intermediate_steps.append((action_step, result))
+            if is_finished_chain:
+                logging.info(f"Finished after {step_count} steps.")
+                status = "finished"
+                yield AgentOutput(
+                    text=result,
+                    agent_type=self.agent_type,
+                    status=status,
+                    intermediate_steps=self.intermediate_steps[-1],
+                )
+                break
+            else:
+                yield AgentOutput(
+                    text="",
+                    agent_type=self.agent_type,
+                    status="thinking",
+                    intermediate_steps=self.intermediate_steps[-1],
+                )
+
+        else:
+            status = "stopped"
+            yield AgentOutput(
+                text="",
+                agent_type=self.agent_type,
+                status=status,
+                intermediate_steps=self.intermediate_steps[-1],
+            )
+
+        return AgentOutput(
+            text=response_text,
+            agent_type=self.agent_type,
+            status=status,
+            total_tokens=total_token,
+            total_cost=total_cost,
+            intermediate_steps=self.intermediate_steps,
+            max_iterations=max_iterations,
+        )
@@ -3,7 +3,7 @@
 from kotaemon.llms import PromptTemplate
 
 zero_shot_react_prompt = PromptTemplate(
-    template="""Answer the following questions as best you can. You have access to the following tools:
+    template="""Answer the following questions as best you can. Give answer in {lang}. You have access to the following tools:
 {tool_description}
 Use the following format:
 
@@ -12,7 +12,7 @@ Thought: you should always think about what to do
 
 Action: the action to take, should be one of [{tool_names}]
 
-Action Input: the input to the action
+Action Input: the input to the action, should be different from the action input of the same action in previous steps.
 
 Observation: the result of the action
 
@@ -1,14 +1,18 @@
 import logging
 import re
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from typing import Any
 
+import tiktoken
+
 from kotaemon.agents.base import BaseAgent
 from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
 from kotaemon.agents.tools import BaseTool
 from kotaemon.agents.utils import get_plugin_response_content
-from kotaemon.base import Node, Param
-from kotaemon.indices.qa import CitationPipeline
+from kotaemon.base import Document, Node, Param
+from kotaemon.indices.qa.citation import CitationPipeline
+from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import BaseLLM, PromptTemplate
 
 from .planner import Planner
@@ -22,6 +26,7 @@ class RewooAgent(BaseAgent):
     name: str = "RewooAgent"
     agent_type: AgentType = AgentType.rewoo
     description: str = "RewooAgent for answering multi-step reasoning questions"
+    output_lang: str = "English"
     planner_llm: BaseLLM
     solver_llm: BaseLLM
     prompt_template: dict[str, PromptTemplate] = Param(
@@ -34,6 +39,16 @@ class RewooAgent(BaseAgent):
     examples: dict[str, str | list[str]] = Param(
         default_callback=lambda _: {}, help="Examples to be used in the agent."
     )
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=3000,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )
 
     @Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
     def planner(self):
@@ -50,6 +65,7 @@ class RewooAgent(BaseAgent):
             model=self.solver_llm,
             prompt_template=self.prompt_template.get("Solver", None),
             examples=self.examples.get("Solver", None),
+            output_lang=self.output_lang,
         )
 
     def _parse_plan_map(
@@ -159,8 +175,13 @@ class RewooAgent(BaseAgent):
                 tool_input = tool_input[:-1]
             # find variables in input and replace with previous evidences
             for var in re.findall(r"#E\d+", tool_input):
+                print("Tool input: ", tool_input)
+                print("Var: ", var)
+                print("Worker evidences: ", worker_evidences)
                 if var in worker_evidences:
-                    tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
+                    tool_input = tool_input.replace(
+                        var, worker_evidences.get(var, "") or ""
+                    )
             try:
                 selected_plugin = self._find_plugin(tool)
                 if selected_plugin is None:
@@ -216,7 +237,7 @@ class RewooAgent(BaseAgent):
                 resp = r.result()
                 plugin_cost += resp["plugin_cost"]
                 plugin_token += resp["plugin_token"]
-                worker_evidences[resp["e"]] = resp["evidence"]
+                worker_evidences[resp["e"]] = self._trim_evidence(resp["evidence"])
                 output.done()
 
         return worker_evidences, plugin_cost, plugin_token
@@ -226,6 +247,13 @@ class RewooAgent(BaseAgent):
             if p.name == name:
                 return p
 
+    def _trim_evidence(self, evidence: str):
+        if evidence:
+            texts = self.trim_func([Document(text=evidence)])
+            evidence = texts[0].text
+            logging.info(f"len (trimmed): {len(evidence)}")
+        return evidence
+
     @BaseAgent.safeguard_run
     def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
         """
@@ -269,5 +297,69 @@ class RewooAgent(BaseAgent):
             total_tokens=total_token,
             total_cost=total_cost,
             citation=citation,
-            metadata={"citation": citation},
+            metadata={"citation": citation, "worker_log": worker_log},
+        )
+
+    def stream(self, instruction: str, use_citation: bool = False):
+        """
+        Stream the agent with a given instruction.
+        """
+        logging.info(f"Streaming {self.name} with instruction: {instruction}")
+        total_cost = 0.0
+        total_token = 0
+
+        # Plan
+        planner_output = self.planner(instruction)
+        planner_text_output = planner_output.text
+        plan_to_es, plans = self._parse_plan_map(planner_text_output)
+        planner_evidences, evidence_level = self._parse_planner_evidences(
+            planner_text_output
+        )
+
+        print("Planner output:", planner_text_output)
+        # Work
+        worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
+            planner_evidences, evidence_level
+        )
+        worker_log = ""
+        for plan in plan_to_es:
+            worker_log += f"{plan}: {plans[plan]}\n"
+            current_progress = f"{plan}: {plans[plan]}\n"
+            for e in plan_to_es[plan]:
+                worker_log += f"{e}: {worker_evidences[e]}\n"
+                current_progress += f"{e}: {worker_evidences[e]}\n"
+
+            yield AgentOutput(
+                text="",
+                agent_type=self.agent_type,
+                status="thinking",
+                intermediate_steps=[{"worker_log": current_progress}],
+            )
+
+        # Solve
+        solver_response = ""
+        for solver_output in self.solver.stream(instruction, worker_log):
+            solver_output_text = solver_output.text
+            solver_response += solver_output_text
+            yield AgentOutput(
+                text=solver_output_text,
+                agent_type=self.agent_type,
+                status="thinking",
+            )
+        if use_citation:
+            citation_pipeline = CitationPipeline(llm=self.solver_llm)
+            citation = citation_pipeline.invoke(
+                context=worker_log, question=instruction
+            )
+        else:
+            citation = None
+
+        return AgentOutput(
+            text="",
+            agent_type=self.agent_type,
+            status="finished",
+            total_tokens=total_token,
+            total_cost=total_cost,
+            citation=citation,
+            metadata={"citation": citation, "worker_log": worker_log},
         )
@@ -81,3 +81,26 @@ class Planner(BaseComponent):
             raise ValueError("Planner failed to retrieve response from LLM") from e
 
         return response
+
+    def stream(self, instruction: str, output: BaseScratchPad = BaseScratchPad()):
+        response = None
+        output.info("Running Planner")
+        prompt = self._compose_prompt(instruction)
+        output.debug(f"Prompt: {prompt}")
+
+        response = ""
+        try:
+            for text in self.model.stream(prompt):
+                response += text
+                yield text
+            self.log_progress(".planner", response=response)
+            output.info("Planner run successful.")
+        except NotImplementedError:
+            print("Streaming is not supported, falling back to normal run")
+            response = self.model(prompt)
+            yield response
+        except ValueError as e:
+            output.error("Planner failed to retrieve response from LLM")
+            raise ValueError("Planner failed to retrieve response from LLM") from e
+
+        return response
@@ -81,7 +81,7 @@ And so on...
 
 zero_shot_solver_prompt = PromptTemplate(
     template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.
 
 ##My Plans and Evidences##
 {plan_evidence}
@@ -99,7 +99,7 @@ So, <your conclusion>.
 
 few_shot_solver_prompt = PromptTemplate(
     template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.
 
 ##My Plans and Evidences##
 {plan_evidence}
@@ -11,6 +11,7 @@ class Solver(BaseComponent):
     model: BaseLLM
     prompt_template: Optional[PromptTemplate] = None
     examples: Optional[Union[str, List[str]]] = None
+    output_lang: str = "English"
 
     def _compose_fewshot_prompt(self) -> str:
         if self.examples is None:
@@ -20,7 +21,7 @@ class Solver(BaseComponent):
         else:
             return "\n\n".join([e.strip("\n") for e in self.examples])
 
-    def _compose_prompt(self, instruction, plan_evidence) -> str:
+    def _compose_prompt(self, instruction, plan_evidence, output_lang) -> str:
         """
         Compose the prompt from template, plan&evidence, examples and instruction.
         """
@@ -28,20 +29,28 @@ class Solver(BaseComponent):
         if self.prompt_template is not None:
             if "fewshot" in self.prompt_template.placeholders:
                 return self.prompt_template.populate(
-                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+                    plan_evidence=plan_evidence,
+                    fewshot=fewshot,
+                    task=instruction,
+                    lang=output_lang,
                 )
             else:
                 return self.prompt_template.populate(
-                    plan_evidence=plan_evidence, task=instruction
+                    plan_evidence=plan_evidence, task=instruction, lang=output_lang
                 )
         else:
             if self.examples is not None:
                 return few_shot_solver_prompt.populate(
-                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+                    plan_evidence=plan_evidence,
+                    fewshot=fewshot,
+                    task=instruction,
+                    lang=output_lang,
                 )
             else:
                 return zero_shot_solver_prompt.populate(
-                    plan_evidence=plan_evidence, task=instruction
+                    plan_evidence=plan_evidence,
+                    task=instruction,
+                    lang=output_lang,
                 )
 
     def run(
@@ -54,7 +63,7 @@ class Solver(BaseComponent):
         output.info("Running Solver")
         output.debug(f"Instruction: {instruction}")
         output.debug(f"Plan Evidence: {plan_evidence}")
-        prompt = self._compose_prompt(instruction, plan_evidence)
+        prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
         output.debug(f"Prompt: {prompt}")
         try:
             response = self.model(prompt)
@@ -63,3 +72,28 @@ class Solver(BaseComponent):
             output.error("Solver failed to retrieve response from LLM")
 
         return response
+
+    def stream(
+        self,
+        instruction: str,
+        plan_evidence: str,
+        output: BaseScratchPad = BaseScratchPad(),
+    ) -> Any:
+        response = ""
+        output.info("Running Solver")
+        output.debug(f"Instruction: {instruction}")
+        output.debug(f"Plan Evidence: {plan_evidence}")
+        prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
+        output.debug(f"Prompt: {prompt}")
+        try:
+            for text in self.model.stream(prompt):
+                response += text.text
+                yield text
+            output.info("Planner run successful.")
+        except NotImplementedError:
+            response = self.model(prompt).text
+            output.info("Solver run successful.")
+        except ValueError:
+            output.error("Solver failed to retrieve response from LLM")
+
+        return response
@@ -1,4 +1,5 @@
 from typing import AnyStr, Optional, Type
+from urllib.error import HTTPError
 
 from langchain.utilities import SerpAPIWrapper
 from pydantic import BaseModel, Field
@@ -26,12 +27,17 @@ class GoogleSearchTool(BaseTool):
                 "install googlesearch using `pip3 install googlesearch-python` to "
                 "use this tool"
             )
-        output = ""
-        search_results = search(query, advanced=True)
-        if search_results:
-            output = "\n".join(
-                "{} {}".format(item.title, item.description) for item in search_results
-            )
+        try:
+            output = ""
+            search_results = search(query, advanced=True)
+            if search_results:
+                output = "\n".join(
+                    "{} {}".format(item.title, item.description)
+                    for item in search_results
+                )
+        except HTTPError:
+            output = "No evidence found."
+
         return output
 
 
@@ -2,9 +2,10 @@ from typing import AnyStr, Optional, Type
 
 from pydantic import BaseModel, Field
 
+from kotaemon.agents.tools.base import ToolException
 from kotaemon.llms import BaseLLM
 
-from .base import BaseTool, ToolException
+from .base import BaseTool
 
 
 class LLMArgs(BaseModel):
libs/ktem/ktem/reasoning/react.py (new file, 329 lines):

import html
import logging
from typing import AnyStr, Optional, Type

from ktem.llms.manager import llms
from ktem.reasoning.base import BaseReasoning
from ktem.utils.generator import Generator
from ktem.utils.render import Render
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field

from kotaemon.agents import (
    BaseTool,
    GoogleSearchTool,
    LLMTool,
    ReactAgent,
    WikipediaTool,
)
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate

logger = logging.getLogger(__name__)


class DocSearchArgs(BaseModel):
    query: str = Field(..., description="a search query as input to the doc search")


class DocSearchTool(BaseTool):
    name: str = "docsearch"
    description: str = (
        "A storage that contains internal documents. If you lack any specific "
        "private information to answer the question, you can search in this "
        "document storage. Furthermore, if you are unsure about which document that "
        "the user refers to, likely the user already selects the target document in "
        "this document storage, you just need to do normal search. If possible, "
        "formulate the search query as specific as possible."
    )
    args_schema: Optional[Type[BaseModel]] = DocSearchArgs
    retrievers: list[BaseComponent] = []

    def _run_tool(self, query: AnyStr) -> AnyStr:
        docs = []
        doc_ids = []
        for retriever in self.retrievers:
            for doc in retriever(text=query):
                if doc.doc_id not in doc_ids:
                    docs.append(doc)
                    doc_ids.append(doc.doc_id)

        return self.prepare_evidence(docs)

    def prepare_evidence(self, docs, trim_len: int = 4000):
        evidence = ""
        table_found = 0

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = filename = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            elif retrieved_item.metadata.get("type", "") == "image":
                retrieved_content = retrieved_item.metadata.get("image_origin", "")
                retrieved_caption = html.escape(retrieved_item.get_content())
                evidence += (
                    f"<br><b>Figure from {source}</b>\n" + retrieved_caption + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

            print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
            print("Score", retrieved_item.metadata.get("relevance_score", None))

        # trim context by trim_len
        if evidence:
            text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=trim_len,
                chunk_overlap=0,
                separator=" ",
                model_name="gpt-3.5-turbo",
            )
            texts = text_splitter.split_text(evidence)
            evidence = texts[0]

        return Document(content=evidence)


TOOL_REGISTRY = {
    "Google": GoogleSearchTool(),
    "Wikipedia": WikipediaTool(),
    "LLM": LLMTool(),
    "SearchDoc": DocSearchTool(),
}

DEFAULT_QA_PROMPT = (
    "Answer the following questions as best you can. Give answer in {lang}. "
    "You have access to the following tools:\n"
    "{tool_description}\n"
    "Use the following format:\n\n"
    "Question: the input question you must answer\n"
    "Thought: you should always think about what to do\n\n"
    "Action: the action to take, should be one of [{tool_names}]\n\n"
    "Action Input: the input to the action, should be different from the action input "
    "of the same action in previous steps.\n\n"
    "Observation: the result of the action\n\n"
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n"
    "#Thought: I now know the final answer\n"
    "Final Answer: the final answer to the original input question\n\n"
    "Begin! After each Action Input.\n\n"
    "Question: {instruction}\n"
    "Thought: {agent_scratchpad}\n"
)

DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "
)


class RewriteQuestionPipeline(BaseComponent):
    """Rewrite user question

    Args:
        llm: the language model to rewrite question
        rewrite_template: the prompt template for llm to paraphrase a text input
        lang: the language of the answer. Currently support English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    rewrite_template: str = DEFAULT_REWRITE_PROMPT

    lang: str = "English"

    def run(self, question: str) -> Document:  # type: ignore
        prompt_template = PromptTemplate(self.rewrite_template)
        prompt = prompt_template.populate(question=question, lang=self.lang)
        messages = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content=prompt),
        ]
        return self.llm(messages)


class ReactAgentPipeline(BaseReasoning):
    """Question answering pipeline using ReAct agent."""

    class Config:
        allow_extra = True

    retrievers: list[BaseComponent]
    agent: ReactAgent = ReactAgent.withx()
    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
    use_rewrite: bool = False

    def prepare_citation(self, step_id, step, output, status) -> Document:
        header = "<b>Step {id}</b>: {log}".format(id=step_id, log=step.log)
        content = (
            "<b>Action</b>: <em>{tool}[{input}]</em>\n\n<b>Output</b>: {output}"
        ).format(
            tool=step.tool if status == "thinking" else "",
            input=step.tool_input.replace("\n", "") if status == "thinking" else "",
            output=output if status == "thinking" else "Finished",
        )
        return Document(
            channel="info",
            content=Render.collapsible(
                header=header,
                content=Render.table(content),
                open=True,
            ),
        )

    async def ainvoke(  # type: ignore
        self, message, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Document:
        if self.use_rewrite:
            rewrite = await self.rewrite_pipeline(question=message)
            message = rewrite.text

        answer = self.agent(message)
        self.report_output(Document(content=answer.text, channel="chat"))

        intermediate_steps = answer.intermediate_steps
        for _, step_output in intermediate_steps:
            self.report_output(Document(content=step_output, channel="info"))

        self.report_output(None)
        return answer

    def stream(self, message, conv_id: str, history: list, **kwargs):
        if self.use_rewrite:
            rewrite = self.rewrite_pipeline(question=message)
            message = rewrite.text
            yield Document(
                channel="info",
                content=f"Rewrote the message to: {rewrite.text}",
            )

        output_stream = Generator(self.agent.stream(message))
        idx = 0
        for item in output_stream:
            idx += 1
            if item.status == "thinking":
                step, step_output = item.intermediate_steps
                yield Document(
                    channel="info",
                    content=self.prepare_citation(idx, step, step_output, item.status),
                )
            else:
                yield Document(
                    channel="chat",
                    content=item.text,
                )
                step, step_output = item.intermediate_steps
                yield Document(
                    channel="info",
                    content=self.prepare_citation(idx, step, step_output, item.status),
                )

        return output_stream.value

    @classmethod
    def get_pipeline(
        cls, settings: dict, states: dict, retrievers: list | None = None
    ) -> BaseReasoning:
        _id = cls.get_info()["id"]
        prefix = f"reasoning.options.{_id}"

        llm_name = settings[f"{prefix}.llm"]
        llm = llms.get(llm_name, llms.get_default())

        pipeline = ReactAgentPipeline(retrievers=retrievers)
        pipeline.agent.llm = llm
        pipeline.agent.max_iterations = settings[f"{prefix}.max_iterations"]
        tools = []
        for tool_name in settings[f"reasoning.options.{_id}.tools"]:
            tool = TOOL_REGISTRY[tool_name]
            if tool_name == "SearchDoc":
                tool.retrievers = retrievers
            elif tool_name == "LLM":
                tool.llm = llm
            tools.append(tool)
        pipeline.agent.plugins = tools
        pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
            settings["reasoning.lang"], "English"
        )
        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
        pipeline.agent.prompt_template = PromptTemplate(settings[f"{prefix}.qa_prompt"])

        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:
        llm = ""
        llm_choices = [("(default)", "")]
        try:
            llm_choices += [(_, _) for _ in llms.options().keys()]
        except Exception as e:
            logger.exception(f"Failed to get LLM options: {e}")

        tool_choices = ["Wikipedia", "Google", "LLM", "SearchDoc"]

        return {
            "llm": {
                "name": "Language model",
                "value": llm,
                "component": "dropdown",
                "choices": llm_choices,
                "info": (
                    "The language model to use for generating the answer. If None, "
                    "the application default language model will be used."
                ),
            },
            "tools": {
                "name": "Tools for knowledge retrieval",
                "value": ["SearchDoc", "LLM"],
                "component": "checkboxgroup",
                "choices": tool_choices,
            },
            "max_iterations": {
                "name": "Maximum number of iterations the LLM can go through",
                "value": 5,
                "component": "number",
            },
            "qa_prompt": {
                "name": "QA Prompt",
                "value": DEFAULT_QA_PROMPT,
            },
        }

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "ReAct",
            "name": "ReAct Agent",
            "description": "Implementing ReAct paradigm",
        }
libs/ktem/ktem/reasoning/rewoo.py (new file, 462 lines):

import html
import logging
from difflib import SequenceMatcher
from typing import AnyStr, Generator, Optional, Type

from ktem.llms.manager import llms
from ktem.reasoning.base import BaseReasoning
from ktem.utils.generator import Generator as GeneratorWrapper
from ktem.utils.render import Render
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field

from kotaemon.agents import (
    BaseTool,
    GoogleSearchTool,
    LLMTool,
    RewooAgent,
    WikipediaTool,
)
from kotaemon.base import BaseComponent, Document, HumanMessage, Node, SystemMessage
from kotaemon.llms import ChatLLM, PromptTemplate

logger = logging.getLogger(__name__)


DEFAULT_PLANNER_PROMPT = (
    "You are an AI agent who makes step-by-step plans to solve a problem under the "
    "help of external tools. For each step, make one plan followed by one tool-call, "
    "which will be executed later to retrieve evidence for that step.\n"
    "You should store each evidence into a distinct variable #E1, #E2, #E3 ... that "
    "can be referred to in later tool-call inputs.\n\n"
    "##Available Tools##\n"
    "{tool_description}\n\n"
    "##Output Format (Replace '<...>')##\n"
    "#Plan1: <describe your plan here>\n"
    "#E1: <toolname>[<input here>] (eg. Search[What is Python])\n"
    "#Plan2: <describe next plan>\n"
    "#E2: <toolname>[<input here, you can use #E1 to represent its expected output>]\n"
    "And so on...\n\n"
    "##Your Task##\n"
    "{task}\n\n"
    "##Now Begin##\n"
)

DEFAULT_SOLVER_PROMPT = (
    "You are an AI agent who solves a problem with my assistance. I will provide "
    "step-by-step plans(#Plan) and evidences(#E) that could be helpful.\n"
    "Your task is to briefly summarize each step, then make a short final conclusion "
    "for your task. Give answer in {lang}.\n\n"
    "##My Plans and Evidences##\n"
    "{plan_evidence}\n\n"
    "##Example Output##\n"
    "First, I <did something> , and I think <...>; Second, I <...>, "
    "and I think <...>; ....\n"
    "So, <your conclusion>.\n\n"
    "##Your Task##\n"
    "{task}\n\n"
    "##Now Begin##\n"
)


class DocSearchArgs(BaseModel):
    query: str = Field(..., description="a search query as input to the doc search")


class DocSearchTool(BaseTool):
    name: str = "docsearch"
    description: str = (
        "A storage that contains internal documents. If you lack any specific "
        "private information to answer the question, you can search in this "
        "document storage. Furthermore, if you are unsure about which document that "
        "the user refers to, likely the user already selects the target document in "
        "this document storage, you just need to do normal search. If possible, "
        "formulate the search query as specific as possible."
    )
    args_schema: Optional[Type[BaseModel]] = DocSearchArgs
    retrievers: list[BaseComponent] = []

    def _run_tool(self, query: AnyStr) -> AnyStr:
        docs = []
        doc_ids = []
        for retriever in self.retrievers:
            for doc in retriever(text=query):
                if doc.doc_id not in doc_ids:
                    docs.append(doc)
                    doc_ids.append(doc.doc_id)

        return self.prepare_evidence(docs)

    def prepare_evidence(self, docs, trim_len: int = 3000):
        evidence = ""
        table_found = 0

        for _id, retrieved_item in enumerate(docs):
            retrieved_content = ""
            page = retrieved_item.metadata.get("page_label", None)
            source = filename = retrieved_item.metadata.get("file_name", "-")
            if page:
                source += f" (Page {page})"
            if retrieved_item.metadata.get("type", "") == "table":
                if table_found < 5:
                    retrieved_content = retrieved_item.metadata.get("table_origin", "")
                    if retrieved_content not in evidence:
                        table_found += 1
                        evidence += (
                            f"<br><b>Table from {source}</b>\n"
                            + retrieved_content
                            + "\n<br>"
                        )
            elif retrieved_item.metadata.get("type", "") == "chatbot":
                retrieved_content = retrieved_item.metadata["window"]
                evidence += (
                    f"<br><b>Chatbot scenario from {filename} (Row {page})</b>\n"
                    + retrieved_content
                    + "\n<br>"
                )
            elif retrieved_item.metadata.get("type", "") == "image":
                retrieved_content = retrieved_item.metadata.get("image_origin", "")
                retrieved_caption = html.escape(retrieved_item.get_content())
                # PWS doesn't support VLM for images, we will just store the caption
                evidence += (
                    f"<br><b>Figure from {source}</b>\n" + retrieved_caption + "\n<br>"
                )
            else:
                if "window" in retrieved_item.metadata:
                    retrieved_content = retrieved_item.metadata["window"]
                else:
                    retrieved_content = retrieved_item.text
                retrieved_content = retrieved_content.replace("\n", " ")
                if retrieved_content not in evidence:
                    evidence += (
                        f"<br><b>Content from {source}: </b> "
                        + retrieved_content
                        + " \n<br>"
                    )

            print("Retrieved #{}: {}".format(_id, retrieved_content))
            print("Score", retrieved_item.metadata.get("relevance_score", None))

        # trim context by trim_len
        if evidence:
            text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
                chunk_size=trim_len,
                chunk_overlap=0,
                separator=" ",
                model_name="gpt-3.5-turbo",
            )
            texts = text_splitter.split_text(evidence)
            evidence = texts[0]

        return Document(content=evidence)


TOOL_REGISTRY = {
    "Google": GoogleSearchTool(),
    "Wikipedia": WikipediaTool(),
    "LLM": LLMTool(),
    "SearchDoc": DocSearchTool(),
}

DEFAULT_REWRITE_PROMPT = (
    "Given the following question, rephrase and expand it "
    "to help you do better answering. Maintain all information "
    "in the original question. Keep the question as concise as possible. "
    "Give answer in {lang}\n"
    "Original question: {question}\n"
    "Rephrased question: "
)


class RewriteQuestionPipeline(BaseComponent):
    """Rewrite user question

    Args:
        llm: the language model to rewrite question
        rewrite_template: the prompt template for llm to paraphrase a text input
        lang: the language of the answer. Currently support English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    rewrite_template: str = DEFAULT_REWRITE_PROMPT

    lang: str = "English"

    def run(self, question: str) -> Document:  # type: ignore
        prompt_template = PromptTemplate(self.rewrite_template)
        prompt = prompt_template.populate(question=question, lang=self.lang)
        messages = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content=prompt),
        ]
        return self.llm(messages)


def find_text(llm_output, context):
    sentence_list = llm_output.split("\n")
    matches = []
    for sentence in sentence_list:
        match = SequenceMatcher(
            None, sentence, context, autojunk=False
        ).find_longest_match()
        matches.append((match.b, match.b + match.size))
    return matches


class RewooAgentPipeline(BaseReasoning):
    """Question answering pipeline using ReWOO Agent."""

    class Config:
        allow_extra = True

    retrievers: list[BaseComponent]
    agent: RewooAgent = RewooAgent.withx()
    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
    use_rewrite: bool = False
    enable_citation: bool = False

    def format_info_panel(self, worker_log):
        header = ""
        content = []

        for line in worker_log.splitlines():
            if line.startswith("#Plan"):
                # line starts with #Plan should be marked as a new segment
                header = line
            elif line.startswith("#"):
                # stop markdown from rendering big headers
                line = "\\" + line
                content.append(line)
            else:
                content.append(line)

        if not header:
            return

        return Document(
            channel="info",
            content=Render.collapsible(
                header=header,
                content=Render.table("\n".join(content)),
                open=True,
            ),
        )

    def prepare_citation(self, answer) -> list[Document]:
        """Prepare citation to show on the UI"""
        segments = []
        split_indices = [
            0,
        ]
        start_indices = set()
        text = ""

        if "citation" in answer.metadata and answer.metadata["citation"] is not None:
            context = answer.metadata["worker_log"]
            for fact_with_evidence in answer.metadata["citation"].answer:
                for quote in fact_with_evidence.substring_quote:
                    matches = find_text(quote, context)
                    for match in matches:
                        split_indices.append(match[0])
                        split_indices.append(match[1])
                        start_indices.add(match[0])
            split_indices = sorted(list(set(split_indices)))
            spans = []
            prev = 0
            for index in split_indices:
                if index > prev:
                    spans.append(context[prev:index])
                    prev = index
            spans.append(context[split_indices[-1] :])

            prev = 0
            for span, start_idx in list(zip(spans, split_indices)):
                if start_idx in start_indices:
                    text += Render.highlight(span)
                else:
                    text += span

        else:
            text = answer.metadata["worker_log"]

        # separate text by detect header: #Plan
        for line in text.splitlines():
            if line.startswith("#Plan"):
                # line starts with #Plan should be marked as a new segment
                new_segment = [line]
                segments.append(new_segment)
            elif line.startswith("#"):
                # stop markdown from rendering big headers
                line = "\\" + line
                segments[-1].append(line)
            else:
                segments[-1].append(line)

        outputs = []
        for segment in segments:
            outputs.append(
                Document(
                    channel="info",
                    content=Render.collapsible(
                        header=segment[0],
                        content=Render.table("\n".join(segment[1:])),
                        open=True,
                    ),
                )
            )

        return outputs

    async def ainvoke(  # type: ignore
        self, message, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Document:
        answer = self.agent(message, use_citation=True)
        self.report_output(Document(content=answer.text, channel="chat"))

        refined_citations = self.prepare_citation(answer)
        for _ in refined_citations:
            self.report_output(_)

        self.report_output(None)
        return answer

    def stream(  # type: ignore
        self, message, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Generator[Document, None, Document] | None:
        if self.use_rewrite:
            rewrite = self.rewrite_pipeline(question=message)
            message = rewrite.text
            yield Document(
                channel="info",
                content=f"Rewrote the message to: {rewrite.text}",
            )

        output_stream = GeneratorWrapper(
            self.agent.stream(message, use_citation=self.enable_citation)
        )
        for item in output_stream:
            if item.intermediate_steps:
                for step in item.intermediate_steps:
                    yield Document(
                        channel="info",
                        content=self.format_info_panel(step["worker_log"]),
                    )
            if item.text:
                yield Document(channel="chat", content=item.text)

        answer = output_stream.value
        yield Document(channel="info", content=None)
        refined_citations = self.prepare_citation(answer)
        for _ in refined_citations:
            yield _

        return answer

    @classmethod
    def get_pipeline(
        cls, settings: dict, states: dict, retrievers: list | None = None
    ) -> BaseReasoning:
        _id = cls.get_info()["id"]
        prefix = f"reasoning.options.{_id}"
        pipeline = RewooAgentPipeline(retrievers=retrievers)

        planner_llm_name = settings[f"{prefix}.planner_llm"]
        planner_llm = llms.get(planner_llm_name, llms.get_default())
        solver_llm_name = settings[f"{prefix}.solver_llm"]
        solver_llm = llms.get(solver_llm_name, llms.get_default())

        pipeline.agent.planner_llm = planner_llm
        pipeline.agent.solver_llm = solver_llm

        tools = []
        for tool_name in settings[f"{prefix}.tools"]:
            tool = TOOL_REGISTRY[tool_name]
            if tool_name == "SearchDoc":
                tool.retrievers = retrievers
            elif tool_name == "LLM":
                tool.llm = solver_llm
            tools.append(tool)
        pipeline.agent.plugins = tools
        pipeline.agent.output_lang = {"en": "English", "ja": "Japanese"}.get(
            settings["reasoning.lang"], "English"
        )
        pipeline.agent.prompt_template["Planner"] = PromptTemplate(
            settings[f"{prefix}.planner_prompt"]
        )
        pipeline.agent.prompt_template["Solver"] = PromptTemplate(
            settings[f"{prefix}.solver_prompt"]
        )

        pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
        pipeline.rewrite_pipeline.llm = (
            planner_llm  # TODO: separate llm for rewrite if needed
        )

        return pipeline

    @classmethod
    def get_user_settings(cls) -> dict:

        llm = ""
        llm_choices = [("(default)", "")]
        try:
            llm_choices += [(_, _) for _ in llms.options().keys()]
        except Exception as e:
            logger.exception(f"Failed to get LLM options: {e}")

        tool_choices = ["Wikipedia", "Google", "LLM", "SearchDoc"]

        return {
            "planner_llm": {
                "name": "Language model for Planner",
                "value": llm,
                "component": "dropdown",
                "choices": llm_choices,
                "info": (
                    "The language model to use for planning. "
                    "This model will generate a plan based on the "
                    "instruction to find the answer."
                ),
            },
            "solver_llm": {
                "name": "Language model for Solver",
                "value": llm,
                "component": "dropdown",
                "choices": llm_choices,
                "info": (
                    "The language model to use for solving. "
                    "This model will generate the answer based on the "
                    "plan generated by the planner and evidences found by the tools."
                ),
            },
            "highlight_citation": {
                "name": "Highlight Citation",
                "value": False,
                "component": "checkbox",
            },
            "tools": {
                "name": "Tools for knowledge retrieval",
                "value": ["SearchDoc", "LLM"],
                "component": "checkboxgroup",
                "choices": tool_choices,
            },
            "planner_prompt": {
                "name": "Planner Prompt",
                "value": DEFAULT_PLANNER_PROMPT,
            },
            "solver_prompt": {
                "name": "Solver Prompt",
                "value": DEFAULT_SOLVER_PROMPT,
            },
        }

    @classmethod
    def get_info(cls) -> dict:
        return {
            "id": "ReWOO",
            "name": "ReWOO Agent",
            "description": (
                "Implementing ReWOO paradigm " "https://arxiv.org/pdf/2305.18323.pdf"
            ),
        }
libs/ktem/ktem/utils/generator.py (new file, 9 lines):

class Generator:
    """A generator that stores return value from another generator"""

    def __init__(self, gen):
        self.gen = gen

    def __iter__(self):
        self.value = yield from self.gen
        return self.value
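The helper above relies on `yield from` propagating the wrapped generator's return value into `self.value`. A small usage sketch (the `make_gen` function is hypothetical, for illustration only):

    def make_gen():
        yield 1
        yield 2
        return "done"  # normally lost when iterating directly

    wrapped = Generator(make_gen())
    for item in wrapped:
        print(item)       # 1, 2
    print(wrapped.value)  # "done", captured via `yield from`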