refractor agents (#100)
* refractor agents * minor cosmetic, add terminal ui for cli * pump to 0.3.4 * Add temporary path * fix unclose files in tests --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
parent
d9e925eb75
commit
797df5a69c
3
.github/workflows/unit-test.yaml
vendored
3
.github/workflows/unit-test.yaml
vendored
|
@ -6,6 +6,9 @@ on:
|
|||
push:
|
||||
branches: [main]
|
||||
|
||||
env:
|
||||
THEFLOW_TEMP_PATH: ./tmp
|
||||
|
||||
jobs:
|
||||
unit-test:
|
||||
if: ${{ !cancelled() }}
|
||||
|
|
|
@ -1,6 +1,25 @@
|
|||
from .base import AgentType, BaseAgent
|
||||
from .langchain import LangchainAgent
|
||||
from .base import BaseAgent
|
||||
from .io import AgentFinish, AgentOutput, AgentType, BaseScratchPad
|
||||
from .langchain_based import LangchainAgent
|
||||
from .react.agent import ReactAgent
|
||||
from .rewoo.agent import RewooAgent
|
||||
from .tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool
|
||||
|
||||
__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"]
|
||||
__all__ = [
|
||||
# agent
|
||||
"BaseAgent",
|
||||
"ReactAgent",
|
||||
"RewooAgent",
|
||||
"LangchainAgent",
|
||||
# tool
|
||||
"BaseTool",
|
||||
"ComponentTool",
|
||||
"GoogleSearchTool",
|
||||
"WikipediaTool",
|
||||
"LLMTool",
|
||||
# io
|
||||
"AgentType",
|
||||
"AgentOutput",
|
||||
"AgentFinish",
|
||||
"BaseScratchPad",
|
||||
]
|
||||
|
|
|
@ -1,45 +1,13 @@
|
|||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from theflow import Node, Param
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms.chats.base import ChatLLM
|
||||
from kotaemon.llms.completions.base import LLM
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from .io import AgentOutput, AgentType
|
||||
from .tools import BaseTool
|
||||
|
||||
BaseLLM = Union[ChatLLM, LLM]
|
||||
|
||||
|
||||
class AgentType(Enum):
|
||||
"""
|
||||
Enumerated type for agent types.
|
||||
"""
|
||||
|
||||
openai = "openai"
|
||||
openai_multi = "openai_multi"
|
||||
openai_tool = "openai_tool"
|
||||
self_ask = "self_ask"
|
||||
react = "react"
|
||||
rewoo = "rewoo"
|
||||
vanilla = "vanilla"
|
||||
|
||||
@staticmethod
|
||||
def get_agent_class(_type: "AgentType"):
|
||||
"""
|
||||
Get agent class from agent type.
|
||||
:param _type: agent type
|
||||
:return: agent class
|
||||
"""
|
||||
if _type == AgentType.rewoo:
|
||||
from .rewoo.agent import RewooAgent
|
||||
|
||||
return RewooAgent
|
||||
else:
|
||||
raise ValueError(f"Unknown agent type: {_type}")
|
||||
|
||||
|
||||
class BaseAgent(BaseComponent):
|
||||
"""Define base agent interface"""
|
||||
|
@ -47,13 +15,17 @@ class BaseAgent(BaseComponent):
|
|||
name: str = Param(help="Name of the agent.")
|
||||
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
|
||||
description: str = Param(
|
||||
help="Description used to tell the model how/when/why to use the agent. "
|
||||
"You can provide few-shot examples as a part of the description. This will be "
|
||||
help=(
|
||||
"Description used to tell the model how/when/why to use the agent. You can"
|
||||
" provide few-shot examples as a part of the description. This will be"
|
||||
" input to the prompt of LLM."
|
||||
)
|
||||
llm: Union[BaseLLM, dict[str, BaseLLM]] = Node(
|
||||
help="Specify LLM to be used in the model, cam be a dict to supply different "
|
||||
"LLMs to multiple purposes in the agent"
|
||||
)
|
||||
llm: Optional[BaseLLM] = Node(
|
||||
help=(
|
||||
"LLM to be used for the agent (optional). LLM must implement BaseLLM"
|
||||
" interface."
|
||||
)
|
||||
)
|
||||
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
|
||||
help="A prompt template or a dict to supply different prompt to the agent"
|
||||
|
@ -63,6 +35,25 @@ class BaseAgent(BaseComponent):
|
|||
help="List of plugins / tools to be used in the agent",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def safeguard_run(run_func, *args, **kwargs):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return run_func(self, *args, **kwargs)
|
||||
except Exception as e:
|
||||
return AgentOutput(
|
||||
text="",
|
||||
agent_type=self.agent_type,
|
||||
status="failed",
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
def add_tools(self, tools: list[BaseTool]) -> None:
|
||||
"""Helper method to add tools and update agent state if needed"""
|
||||
self.plugins.extend(tools)
|
||||
|
||||
def run(self, *args, **kwargs) -> AgentOutput | list[AgentOutput]:
|
||||
"""Run the component."""
|
||||
raise NotImplementedError()
|
||||
|
|
3
knowledgehub/agents/io/__init__.py
Normal file
3
knowledgehub/agents/io/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
|||
from .base import AgentAction, AgentFinish, AgentOutput, AgentType, BaseScratchPad
|
||||
|
||||
__all__ = ["AgentOutput", "AgentFinish", "BaseScratchPad", "AgentType", "AgentAction"]
|
|
@ -2,7 +2,12 @@ import json
|
|||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, NamedTuple, Union
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from kotaemon.base import LLMInterface
|
||||
|
||||
|
||||
def check_log():
|
||||
|
@ -14,6 +19,20 @@ def check_log():
|
|||
return os.environ.get("LOG_PATH", None) is not None
|
||||
|
||||
|
||||
class AgentType(Enum):
|
||||
"""
|
||||
Enumerated type for agent types.
|
||||
"""
|
||||
|
||||
openai = "openai"
|
||||
openai_multi = "openai_multi"
|
||||
openai_tool = "openai_tool"
|
||||
self_ask = "self_ask"
|
||||
react = "react"
|
||||
rewoo = "rewoo"
|
||||
vanilla = "vanilla"
|
||||
|
||||
|
||||
class BaseScratchPad:
|
||||
"""
|
||||
Base class for output handlers.
|
||||
|
@ -217,3 +236,20 @@ class AgentFinish(NamedTuple):
|
|||
|
||||
return_values: dict
|
||||
log: str
|
||||
|
||||
|
||||
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
|
||||
"""Output from an agent.
|
||||
|
||||
Args:
|
||||
text: The text output from the agent.
|
||||
agent_type: The type of agent.
|
||||
status: The status after executing the agent.
|
||||
error: The error message if any.
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: str = "agent"
|
||||
agent_type: AgentType
|
||||
status: Literal["finished", "stopped", "failed"]
|
||||
error: Optional[str] = None
|
|
@ -4,12 +4,11 @@ from langchain.agents import AgentType as LCAgentType
|
|||
from langchain.agents import initialize_agent
|
||||
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
|
||||
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.llms.chats.base import ChatLLM
|
||||
from kotaemon.llms.completions.base import LLM
|
||||
from kotaemon.llms import LLM, ChatLLM
|
||||
|
||||
from .base import AgentType, BaseAgent
|
||||
from .base import BaseAgent
|
||||
from .io import AgentOutput, AgentType
|
||||
from .tools import BaseTool
|
||||
|
||||
|
||||
class LangchainAgent(BaseAgent):
|
||||
|
@ -54,7 +53,9 @@ class LangchainAgent(BaseAgent):
|
|||
# reinit Langchain AgentExecutor
|
||||
self.agent = initialize_agent(
|
||||
langchain_plugins,
|
||||
self.llm._obj,
|
||||
# TODO: could cause bugs for non-langchain llms
|
||||
# related to https://github.com/Cinnamon/kotaemon/issues/73
|
||||
self.llm._obj, # type: ignore
|
||||
agent=self.AGENT_TYPE_MAP[self.agent_type],
|
||||
handle_parsing_errors=True,
|
||||
verbose=True,
|
||||
|
@ -65,17 +66,16 @@ class LangchainAgent(BaseAgent):
|
|||
self.update_agent_tools()
|
||||
return
|
||||
|
||||
def run(self, instruction: str) -> Document:
|
||||
def run(self, instruction: str) -> AgentOutput:
|
||||
assert (
|
||||
self.agent is not None
|
||||
), "Lanchain AgentExecutor is not correclty initialized"
|
||||
|
||||
# Langchain AgentExecutor call
|
||||
output = self.agent(instruction)["output"]
|
||||
return Document(
|
||||
|
||||
return AgentOutput(
|
||||
text=output,
|
||||
metadata={
|
||||
"agent": "langchain",
|
||||
"cost": 0.0,
|
||||
"usage": 0,
|
||||
},
|
||||
agent_type=self.agent_type,
|
||||
status="finished",
|
||||
)
|
|
@ -4,12 +4,11 @@ from typing import Optional
|
|||
|
||||
from theflow import Param
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.agents.base import BaseAgent, BaseLLM
|
||||
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
|
||||
from ..output.base import AgentAction, AgentFinish
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
|
||||
|
@ -22,7 +21,7 @@ class ReactAgent(BaseAgent):
|
|||
name: str = "ReactAgent"
|
||||
agent_type: AgentType = AgentType.react
|
||||
description: str = "ReactAgent for answering multi-step reasoning questions"
|
||||
llm: BaseLLM | dict[str, BaseLLM]
|
||||
llm: BaseLLM
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [], help="List of tools to be used in the agent. "
|
||||
|
@ -34,7 +33,7 @@ class ReactAgent(BaseAgent):
|
|||
default_callback=lambda _: [],
|
||||
help="List of AgentAction and observation (tool) output",
|
||||
)
|
||||
max_iterations = 10
|
||||
max_iterations: int = 10
|
||||
strict_decode: bool = False
|
||||
|
||||
def _compose_plugin_description(self) -> str:
|
||||
|
@ -141,7 +140,7 @@ class ReactAgent(BaseAgent):
|
|||
"""
|
||||
self.intermediate_steps = []
|
||||
|
||||
def run(self, instruction, max_iterations=None):
|
||||
def run(self, instruction, max_iterations=None) -> AgentOutput:
|
||||
"""
|
||||
Run the agent with the given instruction.
|
||||
|
||||
|
@ -161,11 +160,15 @@ class ReactAgent(BaseAgent):
|
|||
logging.info(f"Running {self.name} with instruction: {instruction}")
|
||||
total_cost = 0.0
|
||||
total_token = 0
|
||||
status = "failed"
|
||||
response_text = None
|
||||
|
||||
for _ in range(max_iterations):
|
||||
for step_count in range(1, max_iterations + 1):
|
||||
prompt = self._compose_prompt(instruction)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
|
||||
response = self.llm(
|
||||
prompt, stop=["Observation:"]
|
||||
) # could cause bugs if llm doesn't have `stop` as a parameter
|
||||
response_text = response.text
|
||||
logging.info(f"Response: {response_text}")
|
||||
action_step = self._parse_output(response_text)
|
||||
|
@ -185,13 +188,18 @@ class ReactAgent(BaseAgent):
|
|||
|
||||
self.intermediate_steps.append((action_step, result))
|
||||
if is_finished_chain:
|
||||
logging.info(f"Finished after {step_count} steps.")
|
||||
status = "finished"
|
||||
break
|
||||
else:
|
||||
status = "stopped"
|
||||
|
||||
return Document(
|
||||
return AgentOutput(
|
||||
text=response_text,
|
||||
metadata={
|
||||
"agent": "react",
|
||||
"cost": total_cost,
|
||||
"usage": total_token,
|
||||
},
|
||||
agent_type=self.agent_type,
|
||||
status=status,
|
||||
total_tokens=total_token,
|
||||
total_cost=total_cost,
|
||||
intermediate_steps=self.intermediate_steps,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
|
|
@ -3,15 +3,15 @@ import re
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from theflow import Param
|
||||
from theflow import Node, Param
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.agents.base import BaseAgent
|
||||
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.agents.utils import get_plugin_response_content
|
||||
from kotaemon.indices.qa import CitationPipeline
|
||||
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
|
||||
from ..output.base import BaseScratchPad
|
||||
from ..utils import get_plugin_response_content
|
||||
from .planner import Planner
|
||||
from .solver import Solver
|
||||
|
||||
|
@ -23,7 +23,8 @@ class RewooAgent(BaseAgent):
|
|||
name: str = "RewooAgent"
|
||||
agent_type: AgentType = AgentType.rewoo
|
||||
description: str = "RewooAgent for answering multi-step reasoning questions"
|
||||
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx}
|
||||
planner_llm: BaseLLM
|
||||
solver_llm: BaseLLM
|
||||
prompt_template: dict[str, PromptTemplate] = Param(
|
||||
default_callback=lambda _: {},
|
||||
help="A dict to supply different prompt to the agent.",
|
||||
|
@ -35,17 +36,22 @@ class RewooAgent(BaseAgent):
|
|||
default_callback=lambda _: {}, help="Examples to be used in the agent."
|
||||
)
|
||||
|
||||
def _get_llms(self):
|
||||
if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM):
|
||||
return {"Planner": self.llm, "Solver": self.llm}
|
||||
elif (
|
||||
isinstance(self.llm, dict)
|
||||
and "Planner" in self.llm
|
||||
and "Solver" in self.llm
|
||||
):
|
||||
return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
|
||||
else:
|
||||
raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")
|
||||
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
|
||||
def planner(self):
|
||||
return Planner(
|
||||
model=self.planner_llm,
|
||||
plugins=self.plugins,
|
||||
prompt_template=self.prompt_template.get("Planner", None),
|
||||
examples=self.examples.get("Planner", None),
|
||||
)
|
||||
|
||||
@Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
|
||||
def solver(self):
|
||||
return Solver(
|
||||
model=self.solver_llm,
|
||||
prompt_template=self.prompt_template.get("Solver", None),
|
||||
examples=self.examples.get("Solver", None),
|
||||
)
|
||||
|
||||
def _parse_plan_map(
|
||||
self, planner_response: str
|
||||
|
@ -76,13 +82,16 @@ class RewooAgent(BaseAgent):
|
|||
|
||||
plan_to_es: dict[str, list[str]] = dict()
|
||||
plans: dict[str, str] = dict()
|
||||
prev_key = ""
|
||||
for line in valid_chunk:
|
||||
if line.startswith("#Plan"):
|
||||
plan = line.split(":", 1)[0].strip()
|
||||
plans[plan] = line.split(":", 1)[1].strip()
|
||||
plan_to_es[plan] = []
|
||||
elif line.startswith("#E"):
|
||||
plan_to_es[plan].append(line.split(":", 1)[0].strip())
|
||||
key, description = line.split(":", 1)
|
||||
key = key.strip()
|
||||
if key.startswith("#Plan"):
|
||||
plans[key] = description.strip()
|
||||
plan_to_es[key] = []
|
||||
prev_key = key
|
||||
elif key.startswith("#E"):
|
||||
plan_to_es[prev_key].append(key)
|
||||
|
||||
return plan_to_es, plans
|
||||
|
||||
|
@ -218,7 +227,8 @@ class RewooAgent(BaseAgent):
|
|||
if p.name == name:
|
||||
return p
|
||||
|
||||
def run(self, instruction: str, use_citation: bool = False) -> Document:
|
||||
@BaseAgent.safeguard_run
|
||||
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
|
||||
"""
|
||||
Run the agent with a given instruction.
|
||||
"""
|
||||
|
@ -226,27 +236,12 @@ class RewooAgent(BaseAgent):
|
|||
total_cost = 0.0
|
||||
total_token = 0
|
||||
|
||||
planner_llm = self._get_llms()["Planner"]
|
||||
solver_llm = self._get_llms()["Solver"]
|
||||
|
||||
planner = Planner(
|
||||
model=planner_llm,
|
||||
plugins=self.plugins,
|
||||
prompt_template=self.prompt_template.get("Planner", None),
|
||||
examples=self.examples.get("Planner", None),
|
||||
)
|
||||
solver = Solver(
|
||||
model=solver_llm,
|
||||
prompt_template=self.prompt_template.get("Solver", None),
|
||||
examples=self.examples.get("Solver", None),
|
||||
)
|
||||
|
||||
# Plan
|
||||
planner_output = planner(instruction)
|
||||
plannner_text_output = planner_output.text
|
||||
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
|
||||
planner_output = self.planner(instruction)
|
||||
planner_text_output = planner_output.text
|
||||
plan_to_es, plans = self._parse_plan_map(planner_text_output)
|
||||
planner_evidences, evidence_level = self._parse_planner_evidences(
|
||||
plannner_text_output
|
||||
planner_text_output
|
||||
)
|
||||
|
||||
# Work
|
||||
|
@ -260,20 +255,19 @@ class RewooAgent(BaseAgent):
|
|||
worker_log += f"{e}: {worker_evidences[e]}\n"
|
||||
|
||||
# Solve
|
||||
solver_output = solver(instruction, worker_log)
|
||||
solver_output = self.solver(instruction, worker_log)
|
||||
solver_output_text = solver_output.text
|
||||
if use_citation:
|
||||
citation_pipeline = CitationPipeline(llm=solver_llm)
|
||||
citation_pipeline = CitationPipeline(llm=self.solver_llm)
|
||||
citation = citation_pipeline(context=worker_log, question=instruction)
|
||||
else:
|
||||
citation = None
|
||||
|
||||
return Document(
|
||||
return AgentOutput(
|
||||
text=solver_output_text,
|
||||
metadata={
|
||||
"agent": "react",
|
||||
"cost": total_cost,
|
||||
"usage": total_token,
|
||||
"citation": citation,
|
||||
},
|
||||
agent_type=self.agent_type,
|
||||
status="finished",
|
||||
total_tokens=total_token,
|
||||
total_cost=total_cost,
|
||||
citation=citation,
|
||||
)
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Any, List, Optional, Union
|
||||
|
||||
from kotaemon.agents.base import BaseLLM, BaseTool
|
||||
from kotaemon.agents.io import BaseScratchPad
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
from ..base import BaseLLM, BaseTool
|
||||
from ..output.base import BaseScratchPad
|
||||
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
|
||||
|
||||
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from typing import Any, List, Optional, Union
|
||||
|
||||
from kotaemon.agents.io import BaseScratchPad
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms import PromptTemplate
|
||||
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||
|
||||
from ..base import BaseLLM
|
||||
from ..output.base import BaseScratchPad
|
||||
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
|
||||
|
||||
|
||||
|
|
|
@ -11,8 +11,8 @@ class GoogleSearchArgs(BaseModel):
|
|||
|
||||
|
||||
class GoogleSearchTool(BaseTool):
|
||||
name = "google_search"
|
||||
description = (
|
||||
name: str = "google_search"
|
||||
description: str = (
|
||||
"A search engine retrieving top search results as snippets from Google. "
|
||||
"Input should be a search query."
|
||||
)
|
||||
|
|
|
@ -14,8 +14,8 @@ class LLMArgs(BaseModel):
|
|||
|
||||
|
||||
class LLMTool(BaseTool):
|
||||
name = "llm"
|
||||
description = (
|
||||
name: str = "llm"
|
||||
description: str = (
|
||||
"A pretrained LLM like yourself. Useful when you need to act with "
|
||||
"general world knowledge and common sense. Prioritize it when you "
|
||||
"are confident in solving the problem "
|
||||
|
|
|
@ -48,8 +48,8 @@ class WikipediaArgs(BaseModel):
|
|||
class WikipediaTool(BaseTool):
|
||||
"""Tool that adds the capability to query the Wikipedia API."""
|
||||
|
||||
name = "wikipedia"
|
||||
description = (
|
||||
name: str = "wikipedia"
|
||||
description: str = (
|
||||
"Search engine from Wikipedia, retrieving relevant wiki page. "
|
||||
"Useful when you need to get holistic knowledge about people, "
|
||||
"places, companies, historical events, or other subjects. "
|
||||
|
|
|
@ -114,6 +114,7 @@ class LLMInterface(AIMessage):
|
|||
completion_tokens: int = -1
|
||||
total_tokens: int = -1
|
||||
prompt_tokens: int = -1
|
||||
total_cost: float = 0
|
||||
logits: list[list[float]] = Field(default_factory=list)
|
||||
messages: list[AIMessage] = Field(default_factory=list)
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ import os
|
|||
|
||||
import click
|
||||
import yaml
|
||||
from trogon import tui
|
||||
|
||||
|
||||
# check if the output is not a .yml file -> raise error
|
||||
|
@ -14,6 +15,7 @@ def check_config_format(config):
|
|||
raise ValueError("config must be yaml format.")
|
||||
|
||||
|
||||
@tui(command="ui", help="Open the terminal UI") # generate the terminal UI
|
||||
@click.group()
|
||||
def main():
|
||||
pass
|
||||
|
@ -56,8 +58,10 @@ def export(export_path, output):
|
|||
@click.option(
|
||||
"--username",
|
||||
required=False,
|
||||
help="Username for the user. If not provided, the promptui will not have "
|
||||
"authentication.",
|
||||
help=(
|
||||
"Username for the user. If not provided, the promptui will not have "
|
||||
"authentication."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--password",
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .base import ChatLLM
|
||||
from .langchain_based import AzureChatOpenAI
|
||||
from .langchain_based import AzureChatOpenAI, LCChatMixin
|
||||
|
||||
__all__ = ["ChatLLM", "AzureChatOpenAI"]
|
||||
__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"]
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .base import LLM
|
||||
from .langchain_based import AzureOpenAI, OpenAI
|
||||
from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI
|
||||
|
||||
__all__ = ["LLM", "OpenAI", "AzureOpenAI"]
|
||||
__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"]
|
||||
|
|
|
@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
|
|||
# metadata and dependencies
|
||||
[project]
|
||||
name = "kotaemon"
|
||||
version = "0.3.3"
|
||||
version = "0.3.4"
|
||||
requires-python = ">= 3.10"
|
||||
description = "Kotaemon core library for AI development."
|
||||
dependencies = [
|
||||
|
@ -24,6 +24,7 @@ dependencies = [
|
|||
"cookiecutter",
|
||||
"click",
|
||||
"pandas",
|
||||
"trogon",
|
||||
]
|
||||
readme = "README.md"
|
||||
license = { text = "MIT License" }
|
||||
|
|
|
@ -3,52 +3,34 @@ from unittest.mock import patch
|
|||
import pytest
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from kotaemon.agents.base import AgentType
|
||||
from kotaemon.agents.langchain import LangchainAgent
|
||||
from kotaemon.agents.react import ReactAgent
|
||||
from kotaemon.agents.rewoo import RewooAgent
|
||||
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
|
||||
from kotaemon.agents import (
|
||||
AgentType,
|
||||
BaseTool,
|
||||
GoogleSearchTool,
|
||||
LangchainAgent,
|
||||
LLMTool,
|
||||
ReactAgent,
|
||||
RewooAgent,
|
||||
WikipediaTool,
|
||||
)
|
||||
from kotaemon.llms import AzureChatOpenAI
|
||||
|
||||
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
|
||||
|
||||
|
||||
_openai_chat_completion_responses_rewoo = [
|
||||
ChatCompletion.parse_obj(
|
||||
{
|
||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||
"object": "chat.completion",
|
||||
"created": 1692338378,
|
||||
"model": "gpt-35-turbo",
|
||||
"system_fingerprint": None,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"function_call": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||
}
|
||||
)
|
||||
for text in [
|
||||
(
|
||||
REWOO_VALID_PLAN = (
|
||||
"#Plan1: Search for Cinnamon AI company on Google\n"
|
||||
"#E1: google_search[Cinnamon AI company]\n"
|
||||
"#Plan2: Search for Cinnamon on Wikipedia\n"
|
||||
"#E2: wikipedia[Cinnamon]\n"
|
||||
),
|
||||
FINAL_RESPONSE_TEXT,
|
||||
]
|
||||
]
|
||||
)
|
||||
REWOO_INVALID_PLAN = (
|
||||
"#E1: google_search[Cinnamon AI company]\n"
|
||||
"#Plan2: Search for Cinnamon on Wikipedia\n"
|
||||
"#E2: wikipedia[Cinnamon]\n"
|
||||
)
|
||||
|
||||
_openai_chat_completion_responses_react = [
|
||||
ChatCompletion.parse_obj(
|
||||
|
||||
def generate_chat_completion_obj(text):
|
||||
return ChatCompletion.parse_obj(
|
||||
{
|
||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||
"object": "chat.completion",
|
||||
|
@ -70,6 +52,20 @@ _openai_chat_completion_responses_react = [
|
|||
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
_openai_chat_completion_responses_rewoo = [
|
||||
generate_chat_completion_obj(text=text)
|
||||
for text in [REWOO_VALID_PLAN, FINAL_RESPONSE_TEXT]
|
||||
]
|
||||
|
||||
_openai_chat_completion_responses_rewoo_error = [
|
||||
generate_chat_completion_obj(text=text)
|
||||
for text in [REWOO_INVALID_PLAN, FINAL_RESPONSE_TEXT]
|
||||
]
|
||||
|
||||
_openai_chat_completion_responses_react = [
|
||||
generate_chat_completion_obj(text=text)
|
||||
for text in [
|
||||
(
|
||||
"I don't have prior knowledge about Cinnamon AI company, "
|
||||
|
@ -91,28 +87,7 @@ _openai_chat_completion_responses_react = [
|
|||
]
|
||||
|
||||
_openai_chat_completion_responses_react_langchain_tool = [
|
||||
ChatCompletion.parse_obj(
|
||||
{
|
||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||
"object": "chat.completion",
|
||||
"created": 1692338378,
|
||||
"model": "gpt-35-turbo",
|
||||
"system_fingerprint": None,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
"function_call": None,
|
||||
"tool_calls": None,
|
||||
},
|
||||
}
|
||||
],
|
||||
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||
}
|
||||
)
|
||||
generate_chat_completion_obj(text=text)
|
||||
for text in [
|
||||
(
|
||||
"I don't have prior knowledge about Cinnamon AI company, "
|
||||
|
@ -145,6 +120,25 @@ def llm():
|
|||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.resources.chat.completions.Completions.create",
|
||||
side_effect=_openai_chat_completion_responses_rewoo_error,
|
||||
)
|
||||
def test_agent_fail(openai_completion, llm, mock_google_search):
|
||||
plugins = [
|
||||
GoogleSearchTool(),
|
||||
WikipediaTool(),
|
||||
LLMTool(llm=llm),
|
||||
]
|
||||
|
||||
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
|
||||
|
||||
response = agent("Tell me about Cinnamon AI company")
|
||||
openai_completion.assert_called()
|
||||
assert not response
|
||||
assert response.status == "failed"
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.resources.chat.completions.Completions.create",
|
||||
side_effect=_openai_chat_completion_responses_rewoo,
|
||||
|
@ -156,7 +150,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
|
|||
LLMTool(llm=llm),
|
||||
]
|
||||
|
||||
agent = RewooAgent(llm=llm, plugins=plugins)
|
||||
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
|
||||
|
||||
response = agent("Tell me about Cinnamon AI company")
|
||||
openai_completion.assert_called()
|
||||
|
|
|
@ -110,7 +110,7 @@ class TestInMemoryVectorStore:
|
|||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
db.delete(["3"])
|
||||
db.save(save_path=tmp_path / "test_save_load_delete.json")
|
||||
f = open(tmp_path / "test_save_load_delete.json")
|
||||
with open(tmp_path / "test_save_load_delete.json") as f:
|
||||
data = json.load(f)
|
||||
assert (
|
||||
"1" and "2" in data["text_id_to_ref_doc_id"]
|
||||
|
@ -136,7 +136,7 @@ class TestSimpleFileVectorStore:
|
|||
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
|
||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||
db.delete(["3"])
|
||||
f = open(tmp_path / "test_save_load_delete.json")
|
||||
with open(tmp_path / "test_save_load_delete.json") as f:
|
||||
data = json.load(f)
|
||||
assert (
|
||||
"1" and "2" in data["text_id_to_ref_doc_id"]
|
||||
|
|
Loading…
Reference in New Issue
Block a user