refractor agents (#100)
* refractor agents * minor cosmetic, add terminal ui for cli * pump to 0.3.4 * Add temporary path * fix unclose files in tests --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
255
knowledgehub/agents/io/base.py
Normal file
255
knowledgehub/agents/io/base.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
|
||||
|
||||
from pydantic import Extra
|
||||
|
||||
from kotaemon.base import LLMInterface
|
||||
|
||||
|
||||
def check_log():
|
||||
"""
|
||||
Checks if logging has been enabled.
|
||||
:return: True if logging has been enabled, False otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
return os.environ.get("LOG_PATH", None) is not None
|
||||
|
||||
|
||||
class AgentType(Enum):
|
||||
"""
|
||||
Enumerated type for agent types.
|
||||
"""
|
||||
|
||||
openai = "openai"
|
||||
openai_multi = "openai_multi"
|
||||
openai_tool = "openai_tool"
|
||||
self_ask = "self_ask"
|
||||
react = "react"
|
||||
rewoo = "rewoo"
|
||||
vanilla = "vanilla"
|
||||
|
||||
|
||||
class BaseScratchPad:
|
||||
"""
|
||||
Base class for output handlers.
|
||||
|
||||
Attributes:
|
||||
-----------
|
||||
logger : logging.Logger
|
||||
The logger object to log messages.
|
||||
|
||||
Methods:
|
||||
--------
|
||||
stop():
|
||||
Stop the output.
|
||||
|
||||
update_status(output: str, **kwargs):
|
||||
Update the status of the output.
|
||||
|
||||
thinking(name: str):
|
||||
Log that a process is thinking.
|
||||
|
||||
done(_all=False):
|
||||
Log that the process is done.
|
||||
|
||||
stream_print(item: str):
|
||||
Not implemented.
|
||||
|
||||
json_print(item: Dict[str, Any]):
|
||||
Log a JSON object.
|
||||
|
||||
panel_print(item: Any, title: str = "Output", stream: bool = False):
|
||||
Log a panel output.
|
||||
|
||||
clear():
|
||||
Not implemented.
|
||||
|
||||
print(content: str, **kwargs):
|
||||
Log arbitrary content.
|
||||
|
||||
format_json(json_obj: str):
|
||||
Format a JSON object.
|
||||
|
||||
debug(content: str, **kwargs):
|
||||
Log a debug message.
|
||||
|
||||
info(content: str, **kwargs):
|
||||
Log an informational message.
|
||||
|
||||
warning(content: str, **kwargs):
|
||||
Log a warning message.
|
||||
|
||||
error(content: str, **kwargs):
|
||||
Log an error message.
|
||||
|
||||
critical(content: str, **kwargs):
|
||||
Log a critical message.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initialize the BaseOutput object.
|
||||
|
||||
"""
|
||||
self.logger = logging
|
||||
self.log = []
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Stop the output.
|
||||
"""
|
||||
|
||||
def update_status(self, output: str, **kwargs):
|
||||
"""
|
||||
Update the status of the output.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(output)
|
||||
|
||||
def thinking(self, name: str):
|
||||
"""
|
||||
Log that a process is thinking.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(f"{name} is thinking...")
|
||||
|
||||
def done(self, _all=False):
|
||||
"""
|
||||
Log that the process is done.
|
||||
"""
|
||||
|
||||
if check_log():
|
||||
self.logger.info("Done")
|
||||
|
||||
def stream_print(self, item: str):
|
||||
"""
|
||||
Stream print.
|
||||
"""
|
||||
|
||||
def json_print(self, item: Dict[str, Any]):
|
||||
"""
|
||||
Log a JSON object.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(json.dumps(item, indent=2))
|
||||
|
||||
def panel_print(self, item: Any, title: str = "Output", stream: bool = False):
|
||||
"""
|
||||
Log a panel output.
|
||||
|
||||
Args:
|
||||
item : Any
|
||||
The item to log.
|
||||
title : str, optional
|
||||
The title of the panel, defaults to "Output".
|
||||
stream : bool, optional
|
||||
"""
|
||||
if not stream:
|
||||
self.log.append(item)
|
||||
if check_log():
|
||||
self.logger.info("-" * 20)
|
||||
self.logger.info(item)
|
||||
self.logger.info("-" * 20)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Not implemented.
|
||||
"""
|
||||
|
||||
def print(self, content: str, **kwargs):
|
||||
"""
|
||||
Log arbitrary content.
|
||||
"""
|
||||
self.log.append(content)
|
||||
if check_log():
|
||||
self.logger.info(content)
|
||||
|
||||
def format_json(self, json_obj: str):
|
||||
"""
|
||||
Format a JSON object.
|
||||
"""
|
||||
formatted_json = json.dumps(json_obj, indent=2)
|
||||
return formatted_json
|
||||
|
||||
def debug(self, content: str, **kwargs):
|
||||
"""
|
||||
Log a debug message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.debug(content, **kwargs)
|
||||
|
||||
def info(self, content: str, **kwargs):
|
||||
"""
|
||||
Log an informational message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.info(content, **kwargs)
|
||||
|
||||
def warning(self, content: str, **kwargs):
|
||||
"""
|
||||
Log a warning message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.warning(content, **kwargs)
|
||||
|
||||
def error(self, content: str, **kwargs):
|
||||
"""
|
||||
Log an error message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.error(content, **kwargs)
|
||||
|
||||
def critical(self, content: str, **kwargs):
|
||||
"""
|
||||
Log a critical message.
|
||||
"""
|
||||
if check_log():
|
||||
self.logger.critical(content, **kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""Agent's action to take.
|
||||
|
||||
Args:
|
||||
tool: The tool to invoke.
|
||||
tool_input: The input to the tool.
|
||||
log: The log message.
|
||||
"""
|
||||
|
||||
tool: str
|
||||
tool_input: Union[str, dict]
|
||||
log: str
|
||||
|
||||
|
||||
class AgentFinish(NamedTuple):
|
||||
"""Agent's return value when finishing execution.
|
||||
|
||||
Args:
|
||||
return_values: The return values of the agent.
|
||||
log: The log message.
|
||||
"""
|
||||
|
||||
return_values: dict
|
||||
log: str
|
||||
|
||||
|
||||
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
|
||||
"""Output from an agent.
|
||||
|
||||
Args:
|
||||
text: The text output from the agent.
|
||||
agent_type: The type of agent.
|
||||
status: The status after executing the agent.
|
||||
error: The error message if any.
|
||||
"""
|
||||
|
||||
text: str
|
||||
type: str = "agent"
|
||||
agent_type: AgentType
|
||||
status: Literal["finished", "stopped", "failed"]
|
||||
error: Optional[str] = None
|
Reference in New Issue
Block a user