kotaemon/knowledgehub/contribs/promptui/config.py
2023-10-24 11:12:22 +07:00

185 lines
5.7 KiB
Python

"""Get config from Pipeline"""
import inspect
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union
import yaml
from ...base import BaseComponent
from ...chatbot import BaseChatBot
from .base import DEFAULT_COMPONENT_BY_TYPES
def config_from_value(value: Any) -> dict:
    """Build a promptui config entry for a plain default value.

    The UI component is looked up in ``DEFAULT_COMPONENT_BY_TYPES`` using the
    name of the value's type; unknown type names fall back to ``"text"``.

    Args:
        value (Any): default value

    Returns:
        dict: config with the chosen ``"component"`` and a ``"params"`` dict
            carrying the value under ``"value"``
    """
    type_name = type(value).__name__
    params = {"value": value}
    return {
        "component": DEFAULT_COMPONENT_BY_TYPES.get(type_name, "text"),
        "params": params,
    }
def handle_param(param: dict) -> dict:
    """Convert param definition into promptui-compliant config

    Supported gradio's UI components are (https://www.gradio.app/docs/components)
        - CheckBoxGroup: list (multi select)
        - DropDown: list (single select)
        - File
        - Image
        - Number: int / float
        - Radio: list (single select)
        - Slider: int / float
        - TextBox: str

    Args:
        param (dict): param definition; the keys read here are ``"default"``
            and ``"component_ui"``, both optional

    Returns:
        dict: config with ``"component"`` (the explicit ``"component_ui"`` if
            it is truthy, otherwise one chosen from the default's type name,
            falling back to ``"text"``) and ``"params"`` (holding ``"value"``
            only when a usable default exists)
    """
    default = param.get("default")

    # A string default wrapped in "{{" ... "}}" is treated as if no default
    # had been given.
    wrapped_in_braces = (
        isinstance(default, str)
        and default.startswith("{{")
        and default.endswith("}}")
    )
    if wrapped_in_braces:
        default = None

    params = {} if default is None else {"value": default}

    explicit_component = param.get("component_ui", "")
    if explicit_component:
        return {"component": explicit_component, "params": params}

    # No explicit component: infer one from the default's type name.
    type_name = "" if default is None else type(default).__name__
    return {
        "component": DEFAULT_COMPONENT_BY_TYPES.get(type_name, "text"),
        "params": params,
    }
def handle_node(node: dict) -> dict:
    """Convert node definition into promptui-compliant config

    Walks the ``"params"`` and ``"nodes"`` sections of a node definition and
    flattens them into a single dict. Entries coming from a child node are
    keyed as ``"<node name>.<key>"``.

    An entry (param or child node) is skipped when its ``"auto_callback"`` is
    a string or when ``"ignore_ui"`` is truthy. Both keys are optional: a
    definition that omits ``"auto_callback"`` is handled like one whose
    callback is not a string.

    Args:
        node (dict): node definition; ``"params"`` and ``"nodes"`` are both
            optional mappings of name -> definition

    Returns:
        dict: flat mapping of (dotted) param name -> UI config
    """
    config = {}

    for name, param_def in node.get("params", {}).items():
        # `.get` instead of indexing so a definition without the key does not
        # raise KeyError (consistent with how "ignore_ui" is read).
        if isinstance(param_def.get("auto_callback"), str):
            continue
        if param_def.get("ignore_ui", False):
            continue
        config[name] = handle_param(param_def)

    for name, node_def in node.get("nodes", {}).items():
        if isinstance(node_def.get("auto_callback"), str):
            continue
        if node_def.get("ignore_ui", False):
            continue

        # A child without a described default (missing or None) contributes
        # no nested params instead of crashing the recursion.
        child_def = node_def.get("default") or {}
        for key, value in handle_node(child_def).items():
            config[f"{name}.{key}"] = value

        # default_kwargs are written after the nested params, so they win on
        # a key collision.
        for key, value in node_def.get("default_kwargs", {}).items():
            config[f"{name}.{key}"] = config_from_value(value)

    return config
def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
    """Get the input from the pipeline

    Inspects the signature of ``pipeline.run_raw`` and produces one UI input
    definition per parameter. Parameters named ``self``, ``args`` or
    ``kwargs`` are ignored.

    Args:
        pipeline: pipeline instance or class; anything without a ``run_raw``
            attribute yields an empty dict

    Returns:
        dict: parameter name -> input definition. A parameter without a
            default gets only ``{"component": "text"}``. A parameter with a
            default also gets ``"params": {"value": default}``, and its
            component is ``None`` when that default is ``None``.
    """
    if not hasattr(pipeline, "run_raw"):
        return {}

    skipped_names = ("self", "args", "kwargs")
    inputs: Dict[str, Dict] = {}

    for name, param in inspect.signature(pipeline.run_raw).parameters.items():
        if name in skipped_names:
            continue

        default = param.default
        if default is param.empty:
            # Required parameter: plain text box with no preset value.
            inputs[name] = {"component": "text"}
            continue

        inputs[name] = {
            "component": None if default is None else "text",
            "params": {"value": default},
        }

    return inputs
def export_pipeline_to_config(
    pipeline: Union[BaseComponent, Type[BaseComponent]],
    path: Optional[str] = None,
) -> dict:
    """Export a pipeline to a promptui-compliant config dict

    The UI type is ``"chat"`` when the pipeline is a ``BaseChatBot`` instance
    and ``"simple"`` otherwise. For chat pipelines every param key is prefixed
    with ``".bot."``, a ``"system_message"`` text param is added and no inputs
    are exported. For simple pipelines the inputs come from ``handle_input``.
    If the pipeline has a ``_promptui_outputs`` attribute, it replaces the
    default outputs.

    Args:
        pipeline: pipeline instance, or a pipeline class (which is
            instantiated without arguments)
        path: optional YAML file to write the config to. If the file already
            exists, its content is loaded and updated with the new entry, so
            entries for other pipelines are kept and an entry with the same
            key is overwritten.

    Returns:
        dict: single-entry mapping of ``"<module>.<class name>"`` to the
            pipeline's config
    """
    if inspect.isclass(pipeline):
        pipeline = pipeline()

    pipeline_def = pipeline.describe()
    ui_type = "chat" if isinstance(pipeline, BaseChatBot) else "simple"

    if ui_type == "chat":
        params = {f".bot.{k}": v for k, v in handle_node(pipeline_def).items()}
        params["system_message"] = {"component": "text", "params": {"value": ""}}
        outputs = []
        if hasattr(pipeline, "_promptui_outputs"):
            outputs = pipeline._promptui_outputs
        config_obj: dict = {
            "ui-type": ui_type,
            "params": params,
            "inputs": {},
            "outputs": outputs,
            "logs": {
                "full_pipeline": {
                    "input": {
                        "step": ".",
                        "getter": "_get_input",
                    },
                    "output": {
                        "step": ".",
                        "getter": "_get_output",
                    },
                    "preference": {
                        "step": "preference",
                    },
                }
            },
        }
    else:
        outputs = [{"step": ".", "getter": "_get_output", "component": "text"}]
        if hasattr(pipeline, "_promptui_outputs"):
            outputs = pipeline._promptui_outputs
        config_obj = {
            "ui-type": ui_type,
            "params": handle_node(pipeline_def),
            "inputs": handle_input(pipeline),
            "outputs": outputs,
            "logs": {
                "full_pipeline": {
                    "input": {
                        "step": ".",
                        "getter": "_get_input",
                    },
                    "output": {
                        "step": ".",
                        "getter": "_get_output",
                    },
                },
            },
        }

    config = {f"{pipeline.__module__}.{pipeline.__class__.__name__}": config_obj}

    if path is not None:
        merged = config
        if Path(path).is_file():
            with open(path) as f:
                # safe_load returns None for an empty file; fall back to an
                # empty mapping so the update below does not fail.
                merged = yaml.safe_load(f) or {}
            merged.update(config)
        with open(path, "w") as f:
            # sort_keys=False keeps the key order built above.
            yaml.safe_dump(merged, f, sort_keys=False)

    return config