kotaemon/knowledgehub/pipelines/tools/base.py
Tuan Anh Nguyen Dang (Tadashi_Cin) f9fc02a32a [AUR-363, AUR-433, AUR-434] Add Base Tool interface with Wikipedia/Google tools (#30)
* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
2023-09-29 10:18:49 +07:00

136 lines
4.4 KiB
Python

from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import BaseModel
from kotaemon.base import BaseComponent
class ToolException(Exception):
    """Exception a tool raises to signal a recoverable execution error.

    ``BaseTool.run_raw`` catches this exception from ``_run_tool`` and passes
    it to ``BaseTool._handle_tool_error``, which applies the tool's
    ``handle_tool_error`` setting:

    * falsy (``False``, ``None``, ``""``): the exception is re-raised;
    * ``True``: the exception's first argument (or a generic message when it
      has none) becomes the observation;
    * a string: that string becomes the observation;
    * a callable: its return value for this exception becomes the observation.

    In the non-re-raising cases ``run_raw`` returns the observation to its
    caller (e.g. an agent) instead of propagating the error.
    """
class BaseTool(BaseComponent):
    """Base class for tools that an agent (or any other caller) can invoke.

    Subclasses implement ``_run_tool``. ``run_raw`` is the entry point: it
    validates the input against ``args_schema`` (when set), turns it into
    positional / keyword arguments, calls ``_run_tool``, and converts a
    ``ToolException`` into an observation according to ``handle_tool_error``.
    Document and batch processing are not supported (see ``is_document`` and
    ``is_batch``).
    """

    name: str
    """The unique name of the tool that clearly communicates its purpose."""
    description: str
    """Description used to tell the model how/when/why to use the tool.
    You can provide few-shot examples as a part of the description. This will be
    input to the prompt of the LLM.
    """
    args_schema: Optional[Type[BaseModel]] = None
    """Pydantic model class to validate and parse the tool's input arguments."""
    verbose: bool = False
    """Whether to log the tool's progress (logging is not implemented yet)."""
    handle_tool_error: Optional[
        Union[bool, str, Callable[[ToolException], str]]
    ] = False
    """How to handle a ToolException raised by ``_run_tool``: falsy re-raises
    it; True uses the exception message; a str is returned verbatim; a
    callable maps the exception to the returned observation."""

    def _parse_input(
        self,
        tool_input: Union[str, Dict],
    ) -> Union[str, Dict[str, Any]]:
        """Validate ``tool_input`` against ``args_schema``.

        Args:
            tool_input: either a single string, validated as the value of the
                schema's first field, or a mapping of argument names to values.

        Returns:
            The string unchanged; for a mapping, the parsed field values
            restricted to the keys present in ``tool_input`` (fields the
            caller did not supply are dropped). Without a schema the input is
            returned as-is.

        Raises:
            ValueError: if a string is given but ``args_schema`` declares no
                fields that could receive it.
            Pydantic's validation error if the input does not fit the schema.
        """
        args_schema = self.args_schema
        if isinstance(tool_input, str):
            if args_schema is not None:
                # A bare string is bound to the schema's first field. Guard
                # against field-less schemas instead of leaking StopIteration.
                first_field = next(iter(args_schema.__fields__), None)
                if first_field is None:
                    raise ValueError(
                        f"Tool {self.__class__.__name__!r} received a string "
                        f"input but its args_schema declares no fields."
                    )
                args_schema.validate({first_field: tool_input})
            return tool_input

        if args_schema is not None:
            result = args_schema.parse_obj(tool_input)
            # Only forward keys the caller supplied.
            return {k: v for k, v in result.dict().items() if k in tool_input}
        return tool_input

    @abstractmethod
    def _run_tool(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Call tool."""

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        """Split parsed input into ``(args, kwargs)`` for ``_run_tool``."""
        # For backwards compatibility, if run_input is a string,
        # pass as a positional argument.
        if isinstance(tool_input, str):
            return (tool_input,), {}
        return (), tool_input

    def _handle_tool_error(self, e: ToolException) -> Any:
        """Convert ``e`` into an observation according to ``handle_tool_error``.

        Returns:
            ``e.args[0]`` (or ``"Tool execution error"`` when ``e`` has no
            args) if the setting is ``True``; the setting itself if it is a
            string; the callable's return value if it is a callable.

        Raises:
            ToolException: ``e`` itself when the setting is falsy
                (``False``, ``None`` or an empty string).
            ValueError: when the setting has an unsupported type.
        """
        handler = self.handle_tool_error
        if not handler:
            raise e
        if isinstance(handler, bool):
            # handler is True here.
            return e.args[0] if e.args else "Tool execution error"
        if isinstance(handler, str):
            return handler
        if callable(handler):
            return handler(e)
        raise ValueError(
            f"Got unexpected type of `handle_tool_error`. Expected bool, str "
            f"or callable. Received: {handler}"
        )

    def run_raw(
        self,
        tool_input: Union[str, Dict],
        verbose: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the tool on ``tool_input`` and return its observation.

        Args:
            tool_input: a string (passed positionally to ``_run_tool``) or a
                mapping (passed as keyword arguments), after validation by
                ``_parse_input``.
            verbose: currently unused; reserved for progress logging.
            **kwargs: currently unused.

        Returns:
            The value returned by ``_run_tool``, or the observation produced
            by ``_handle_tool_error`` when ``_run_tool`` raises
            ``ToolException``.

        Raises:
            ToolException: re-raised when ``handle_tool_error`` is falsy.
            Any error raised by ``_parse_input`` (validation happens before
            error handling, so it always propagates).
        """
        parsed_input = self._parse_input(tool_input)
        # TODO (verbose_): Add logging
        try:
            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
            return self._run_tool(*tool_args, **tool_kwargs)
        except ToolException as e:
            return self._handle_tool_error(e)

    def run_document(self, *args, **kwargs):
        """Document input is not supported by tools; returns None."""
        pass

    def run_batch_raw(self, *args, **kwargs):
        """Batch input is not supported by tools; returns None."""
        pass

    def run_batch_document(self, *args, **kwargs):
        """Batch document input is not supported by tools; returns None."""
        pass

    def is_document(self, *args, **kwargs) -> bool:
        """Tool does not support processing document"""
        return False

    def is_batch(self, *args, **kwargs) -> bool:
        """Tool does not support processing batch"""
        return False
class ComponentTool(BaseTool):
    """Tool whose execution is delegated to a wrapped component.

    Exposes an existing pipeline / ``BaseComponent`` as a tool: whatever
    positional and keyword arguments ``_run_tool`` receives are forwarded
    unchanged to ``component``, and its return value is the tool's result.
    """

    component: BaseComponent

    def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
        # Input parsing and ToolException handling happen in
        # BaseTool.run_raw around this call; here we only delegate.
        wrapped = self.component
        return wrapped(*args, **kwargs)