refractor agents (#100)

* refractor agents

* minor cosmetic, add terminal ui for cli

* pump to 0.3.4

* Add temporary path

* fix unclose files in tests

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin 2023-12-06 17:06:29 +07:00 committed by GitHub
parent d9e925eb75
commit 797df5a69c
21 changed files with 281 additions and 228 deletions

View File

@ -6,6 +6,9 @@ on:
push:
branches: [main]
env:
THEFLOW_TEMP_PATH: ./tmp
jobs:
unit-test:
if: ${{ !cancelled() }}

View File

@ -1,6 +1,25 @@
from .base import AgentType, BaseAgent
from .langchain import LangchainAgent
from .base import BaseAgent
from .io import AgentFinish, AgentOutput, AgentType, BaseScratchPad
from .langchain_based import LangchainAgent
from .react.agent import ReactAgent
from .rewoo.agent import RewooAgent
from .tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool
__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"]
__all__ = [
# agent
"BaseAgent",
"ReactAgent",
"RewooAgent",
"LangchainAgent",
# tool
"BaseTool",
"ComponentTool",
"GoogleSearchTool",
"WikipediaTool",
"LLMTool",
# io
"AgentType",
"AgentOutput",
"AgentFinish",
"BaseScratchPad",
]

View File

@ -1,45 +1,13 @@
from enum import Enum
from typing import Optional, Union
from theflow import Node, Param
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.llms import BaseLLM, PromptTemplate
from .io import AgentOutput, AgentType
from .tools import BaseTool
BaseLLM = Union[ChatLLM, LLM]
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
openai_multi = "openai_multi"
openai_tool = "openai_tool"
self_ask = "self_ask"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
@staticmethod
def get_agent_class(_type: "AgentType"):
"""
Get agent class from agent type.
:param _type: agent type
:return: agent class
"""
if _type == AgentType.rewoo:
from .rewoo.agent import RewooAgent
return RewooAgent
else:
raise ValueError(f"Unknown agent type: {_type}")
class BaseAgent(BaseComponent):
"""Define base agent interface"""
@ -47,13 +15,17 @@ class BaseAgent(BaseComponent):
name: str = Param(help="Name of the agent.")
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
description: str = Param(
help="Description used to tell the model how/when/why to use the agent. "
"You can provide few-shot examples as a part of the description. This will be "
"input to the prompt of LLM."
help=(
"Description used to tell the model how/when/why to use the agent. You can"
" provide few-shot examples as a part of the description. This will be"
" input to the prompt of LLM."
)
)
llm: Union[BaseLLM, dict[str, BaseLLM]] = Node(
help="Specify LLM to be used in the model, cam be a dict to supply different "
"LLMs to multiple purposes in the agent"
llm: Optional[BaseLLM] = Node(
help=(
"LLM to be used for the agent (optional). LLM must implement BaseLLM"
" interface."
)
)
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
help="A prompt template or a dict to supply different prompt to the agent"
@ -63,6 +35,25 @@ class BaseAgent(BaseComponent):
help="List of plugins / tools to be used in the agent",
)
@staticmethod
def safeguard_run(run_func, *args, **kwargs):
def wrapper(self, *args, **kwargs):
try:
return run_func(self, *args, **kwargs)
except Exception as e:
return AgentOutput(
text="",
agent_type=self.agent_type,
status="failed",
error=str(e),
)
return wrapper
def add_tools(self, tools: list[BaseTool]) -> None:
"""Helper method to add tools and update agent state if needed"""
self.plugins.extend(tools)
def run(self, *args, **kwargs) -> AgentOutput | list[AgentOutput]:
"""Run the component."""
raise NotImplementedError()

View File

@ -0,0 +1,3 @@
from .base import AgentAction, AgentFinish, AgentOutput, AgentType, BaseScratchPad
__all__ = ["AgentOutput", "AgentFinish", "BaseScratchPad", "AgentType", "AgentAction"]

View File

@ -2,7 +2,12 @@ import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Union
from enum import Enum
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
from pydantic import Extra
from kotaemon.base import LLMInterface
def check_log():
@ -14,6 +19,20 @@ def check_log():
return os.environ.get("LOG_PATH", None) is not None
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
openai_multi = "openai_multi"
openai_tool = "openai_tool"
self_ask = "self_ask"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
class BaseScratchPad:
"""
Base class for output handlers.
@ -217,3 +236,20 @@ class AgentFinish(NamedTuple):
return_values: dict
log: str
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
"""Output from an agent.
Args:
text: The text output from the agent.
agent_type: The type of agent.
status: The status after executing the agent.
error: The error message if any.
"""
text: str
type: str = "agent"
agent_type: AgentType
status: Literal["finished", "stopped", "failed"]
error: Optional[str] = None

View File

@ -4,12 +4,11 @@ from langchain.agents import AgentType as LCAgentType
from langchain.agents import initialize_agent
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
from kotaemon.agents.tools import BaseTool
from kotaemon.base.schema import Document
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.llms import LLM, ChatLLM
from .base import AgentType, BaseAgent
from .base import BaseAgent
from .io import AgentOutput, AgentType
from .tools import BaseTool
class LangchainAgent(BaseAgent):
@ -54,7 +53,9 @@ class LangchainAgent(BaseAgent):
# reinit Langchain AgentExecutor
self.agent = initialize_agent(
langchain_plugins,
self.llm._obj,
# TODO: could cause bugs for non-langchain llms
# related to https://github.com/Cinnamon/kotaemon/issues/73
self.llm._obj, # type: ignore
agent=self.AGENT_TYPE_MAP[self.agent_type],
handle_parsing_errors=True,
verbose=True,
@ -65,17 +66,16 @@ class LangchainAgent(BaseAgent):
self.update_agent_tools()
return
def run(self, instruction: str) -> Document:
def run(self, instruction: str) -> AgentOutput:
assert (
self.agent is not None
), "Lanchain AgentExecutor is not correclty initialized"
# Langchain AgentExecutor call
output = self.agent(instruction)["output"]
return Document(
return AgentOutput(
text=output,
metadata={
"agent": "langchain",
"cost": 0.0,
"usage": 0,
},
agent_type=self.agent_type,
status="finished",
)

View File

@ -4,12 +4,11 @@ from typing import Optional
from theflow import Param
from kotaemon.base.schema import Document
from kotaemon.agents.base import BaseAgent, BaseLLM
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
from kotaemon.agents.tools import BaseTool
from kotaemon.llms import PromptTemplate
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import AgentAction, AgentFinish
FINAL_ANSWER_ACTION = "Final Answer:"
@ -22,7 +21,7 @@ class ReactAgent(BaseAgent):
name: str = "ReactAgent"
agent_type: AgentType = AgentType.react
description: str = "ReactAgent for answering multi-step reasoning questions"
llm: BaseLLM | dict[str, BaseLLM]
llm: BaseLLM
prompt_template: Optional[PromptTemplate] = None
plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="List of tools to be used in the agent. "
@ -34,7 +33,7 @@ class ReactAgent(BaseAgent):
default_callback=lambda _: [],
help="List of AgentAction and observation (tool) output",
)
max_iterations = 10
max_iterations: int = 10
strict_decode: bool = False
def _compose_plugin_description(self) -> str:
@ -141,7 +140,7 @@ class ReactAgent(BaseAgent):
"""
self.intermediate_steps = []
def run(self, instruction, max_iterations=None):
def run(self, instruction, max_iterations=None) -> AgentOutput:
"""
Run the agent with the given instruction.
@ -161,11 +160,15 @@ class ReactAgent(BaseAgent):
logging.info(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
status = "failed"
response_text = None
for _ in range(max_iterations):
for step_count in range(1, max_iterations + 1):
prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}")
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
response = self.llm(
prompt, stop=["Observation:"]
) # could cause bugs if llm doesn't have `stop` as a parameter
response_text = response.text
logging.info(f"Response: {response_text}")
action_step = self._parse_output(response_text)
@ -185,13 +188,18 @@ class ReactAgent(BaseAgent):
self.intermediate_steps.append((action_step, result))
if is_finished_chain:
logging.info(f"Finished after {step_count} steps.")
status = "finished"
break
else:
status = "stopped"
return Document(
return AgentOutput(
text=response_text,
metadata={
"agent": "react",
"cost": total_cost,
"usage": total_token,
},
agent_type=self.agent_type,
status=status,
total_tokens=total_token,
total_cost=total_cost,
intermediate_steps=self.intermediate_steps,
max_iterations=max_iterations,
)

View File

@ -3,15 +3,15 @@ import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from theflow import Param
from theflow import Node, Param
from kotaemon.base.schema import Document
from kotaemon.agents.base import BaseAgent
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
from kotaemon.agents.tools import BaseTool
from kotaemon.agents.utils import get_plugin_response_content
from kotaemon.indices.qa import CitationPipeline
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
from kotaemon.llms import BaseLLM, PromptTemplate
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from ..utils import get_plugin_response_content
from .planner import Planner
from .solver import Solver
@ -23,7 +23,8 @@ class RewooAgent(BaseAgent):
name: str = "RewooAgent"
agent_type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions"
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx}
planner_llm: BaseLLM
solver_llm: BaseLLM
prompt_template: dict[str, PromptTemplate] = Param(
default_callback=lambda _: {},
help="A dict to supply different prompt to the agent.",
@ -35,17 +36,22 @@ class RewooAgent(BaseAgent):
default_callback=lambda _: {}, help="Examples to be used in the agent."
)
def _get_llms(self):
if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM):
return {"Planner": self.llm, "Solver": self.llm}
elif (
isinstance(self.llm, dict)
and "Planner" in self.llm
and "Solver" in self.llm
):
return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
else:
raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
def planner(self):
return Planner(
model=self.planner_llm,
plugins=self.plugins,
prompt_template=self.prompt_template.get("Planner", None),
examples=self.examples.get("Planner", None),
)
@Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
def solver(self):
return Solver(
model=self.solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
def _parse_plan_map(
self, planner_response: str
@ -76,13 +82,16 @@ class RewooAgent(BaseAgent):
plan_to_es: dict[str, list[str]] = dict()
plans: dict[str, str] = dict()
prev_key = ""
for line in valid_chunk:
if line.startswith("#Plan"):
plan = line.split(":", 1)[0].strip()
plans[plan] = line.split(":", 1)[1].strip()
plan_to_es[plan] = []
elif line.startswith("#E"):
plan_to_es[plan].append(line.split(":", 1)[0].strip())
key, description = line.split(":", 1)
key = key.strip()
if key.startswith("#Plan"):
plans[key] = description.strip()
plan_to_es[key] = []
prev_key = key
elif key.startswith("#E"):
plan_to_es[prev_key].append(key)
return plan_to_es, plans
@ -218,7 +227,8 @@ class RewooAgent(BaseAgent):
if p.name == name:
return p
def run(self, instruction: str, use_citation: bool = False) -> Document:
@BaseAgent.safeguard_run
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
"""
Run the agent with a given instruction.
"""
@ -226,27 +236,12 @@ class RewooAgent(BaseAgent):
total_cost = 0.0
total_token = 0
planner_llm = self._get_llms()["Planner"]
solver_llm = self._get_llms()["Solver"]
planner = Planner(
model=planner_llm,
plugins=self.plugins,
prompt_template=self.prompt_template.get("Planner", None),
examples=self.examples.get("Planner", None),
)
solver = Solver(
model=solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
# Plan
planner_output = planner(instruction)
plannner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
planner_output = self.planner(instruction)
planner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(planner_text_output)
planner_evidences, evidence_level = self._parse_planner_evidences(
plannner_text_output
planner_text_output
)
# Work
@ -260,20 +255,19 @@ class RewooAgent(BaseAgent):
worker_log += f"{e}: {worker_evidences[e]}\n"
# Solve
solver_output = solver(instruction, worker_log)
solver_output = self.solver(instruction, worker_log)
solver_output_text = solver_output.text
if use_citation:
citation_pipeline = CitationPipeline(llm=solver_llm)
citation_pipeline = CitationPipeline(llm=self.solver_llm)
citation = citation_pipeline(context=worker_log, question=instruction)
else:
citation = None
return Document(
return AgentOutput(
text=solver_output_text,
metadata={
"agent": "react",
"cost": total_cost,
"usage": total_token,
"citation": citation,
},
agent_type=self.agent_type,
status="finished",
total_tokens=total_token,
total_cost=total_cost,
citation=citation,
)

View File

@ -1,10 +1,10 @@
from typing import Any, List, Optional, Union
from kotaemon.agents.base import BaseLLM, BaseTool
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from ..base import BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt

View File

@ -1,10 +1,9 @@
from typing import Any, List, Optional, Union
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from kotaemon.llms import BaseLLM, PromptTemplate
from ..base import BaseLLM
from ..output.base import BaseScratchPad
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt

View File

@ -11,8 +11,8 @@ class GoogleSearchArgs(BaseModel):
class GoogleSearchTool(BaseTool):
name = "google_search"
description = (
name: str = "google_search"
description: str = (
"A search engine retrieving top search results as snippets from Google. "
"Input should be a search query."
)

View File

@ -14,8 +14,8 @@ class LLMArgs(BaseModel):
class LLMTool(BaseTool):
name = "llm"
description = (
name: str = "llm"
description: str = (
"A pretrained LLM like yourself. Useful when you need to act with "
"general world knowledge and common sense. Prioritize it when you "
"are confident in solving the problem "

View File

@ -48,8 +48,8 @@ class WikipediaArgs(BaseModel):
class WikipediaTool(BaseTool):
"""Tool that adds the capability to query the Wikipedia API."""
name = "wikipedia"
description = (
name: str = "wikipedia"
description: str = (
"Search engine from Wikipedia, retrieving relevant wiki page. "
"Useful when you need to get holistic knowledge about people, "
"places, companies, historical events, or other subjects. "

View File

@ -114,6 +114,7 @@ class LLMInterface(AIMessage):
completion_tokens: int = -1
total_tokens: int = -1
prompt_tokens: int = -1
total_cost: float = 0
logits: list[list[float]] = Field(default_factory=list)
messages: list[AIMessage] = Field(default_factory=list)

View File

@ -2,6 +2,7 @@ import os
import click
import yaml
from trogon import tui
# check if the output is not a .yml file -> raise error
@ -14,6 +15,7 @@ def check_config_format(config):
raise ValueError("config must be yaml format.")
@tui(command="ui", help="Open the terminal UI") # generate the terminal UI
@click.group()
def main():
pass
@ -56,8 +58,10 @@ def export(export_path, output):
@click.option(
"--username",
required=False,
help="Username for the user. If not provided, the promptui will not have "
"authentication.",
help=(
"Username for the user. If not provided, the promptui will not have "
"authentication."
),
)
@click.option(
"--password",

View File

@ -1,4 +1,4 @@
from .base import ChatLLM
from .langchain_based import AzureChatOpenAI
from .langchain_based import AzureChatOpenAI, LCChatMixin
__all__ = ["ChatLLM", "AzureChatOpenAI"]
__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"]

View File

@ -1,4 +1,4 @@
from .base import LLM
from .langchain_based import AzureOpenAI, OpenAI
from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI
__all__ = ["LLM", "OpenAI", "AzureOpenAI"]
__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"]

View File

@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
# metadata and dependencies
[project]
name = "kotaemon"
version = "0.3.3"
version = "0.3.4"
requires-python = ">= 3.10"
description = "Kotaemon core library for AI development."
dependencies = [
@ -24,6 +24,7 @@ dependencies = [
"cookiecutter",
"click",
"pandas",
"trogon",
]
readme = "README.md"
license = { text = "MIT License" }

View File

@ -3,73 +3,69 @@ from unittest.mock import patch
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.agents.base import AgentType
from kotaemon.agents.langchain import LangchainAgent
from kotaemon.agents.react import ReactAgent
from kotaemon.agents.rewoo import RewooAgent
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
from kotaemon.agents import (
AgentType,
BaseTool,
GoogleSearchTool,
LangchainAgent,
LLMTool,
ReactAgent,
RewooAgent,
WikipediaTool,
)
from kotaemon.llms import AzureChatOpenAI
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
REWOO_VALID_PLAN = (
"#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
)
REWOO_INVALID_PLAN = (
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
)
def generate_chat_completion_obj(text):
return ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
_openai_chat_completion_responses_rewoo = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
for text in [
(
"#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
),
FINAL_RESPONSE_TEXT,
]
generate_chat_completion_obj(text=text)
for text in [REWOO_VALID_PLAN, FINAL_RESPONSE_TEXT]
]
_openai_chat_completion_responses_rewoo_error = [
generate_chat_completion_obj(text=text)
for text in [REWOO_INVALID_PLAN, FINAL_RESPONSE_TEXT]
]
_openai_chat_completion_responses_react = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
generate_chat_completion_obj(text=text)
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
@ -91,28 +87,7 @@ _openai_chat_completion_responses_react = [
]
_openai_chat_completion_responses_react_langchain_tool = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
generate_chat_completion_obj(text=text)
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
@ -145,6 +120,25 @@ def llm():
)
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_rewoo_error,
)
def test_agent_fail(openai_completion, llm, mock_google_search):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert not response
assert response.status == "failed"
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_rewoo,
@ -156,7 +150,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
LLMTool(llm=llm),
]
agent = RewooAgent(llm=llm, plugins=plugins)
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()

View File

@ -110,8 +110,8 @@ class TestInMemoryVectorStore:
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
db.delete(["3"])
db.save(save_path=tmp_path / "test_save_load_delete.json")
f = open(tmp_path / "test_save_load_delete.json")
data = json.load(f)
with open(tmp_path / "test_save_load_delete.json") as f:
data = json.load(f)
assert (
"1" and "2" in data["text_id_to_ref_doc_id"]
), "save function does not save data completely"
@ -136,8 +136,8 @@ class TestSimpleFileVectorStore:
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
db.delete(["3"])
f = open(tmp_path / "test_save_load_delete.json")
data = json.load(f)
with open(tmp_path / "test_save_load_delete.json") as f:
data = json.load(f)
assert (
"1" and "2" in data["text_id_to_ref_doc_id"]
), "save function does not save data completely"