Upgrade the declarative pipeline for cleaner interface (#51)

This commit is contained in:
Nguyen Trung Duc (john)
2023-10-24 11:12:22 +07:00
committed by GitHub
parent aab982ddc4
commit 9035e25666
26 changed files with 365 additions and 169 deletions

View File

@@ -63,19 +63,19 @@ def handle_node(node: dict) -> dict:
"""Convert node definition into promptui-compliant config"""
config = {}
for name, param_def in node.get("params", {}).items():
if isinstance(param_def["default_callback"], str):
if isinstance(param_def["auto_callback"], str):
continue
if param_def.get("ignore_ui", False):
continue
config[name] = handle_param(param_def)
for name, node_def in node.get("nodes", {}).items():
if isinstance(node_def["default_callback"], str):
if isinstance(node_def["auto_callback"], str):
continue
if node_def.get("ignore_ui", False):
continue
for key, value in handle_node(node_def["default"]).items():
config[f"{name}.{key}"] = value
for key, value in node_def["default_kwargs"].items():
for key, value in node_def.get("default_kwargs", {}).items():
config[f"{name}.{key}"] = config_from_value(value)
return config
@@ -124,11 +124,14 @@ def export_pipeline_to_config(
if ui_type == "chat":
params = {f".bot.{k}": v for k, v in handle_node(pipeline_def).items()}
params["system_message"] = {"component": "text", "params": {"value": ""}}
outputs = []
if hasattr(pipeline, "_promptui_outputs"):
outputs = pipeline._promptui_outputs
config_obj: dict = {
"ui-type": ui_type,
"params": params,
"inputs": {},
"outputs": [],
"outputs": outputs,
"logs": {
"full_pipeline": {
"input": {

View File

@@ -61,6 +61,9 @@ def from_log_to_dict(pipeline_cls: Type[BaseComponent], log_config: dict) -> dic
if name not in logged_infos:
logged_infos[name] = [None] * len(dirs)
if step not in progress:
continue
info = progress[step]
if getter:
if getter in allowed_resultlog_callbacks:

View File

@@ -13,9 +13,9 @@ class John(Base):
primary_hue: colors.Color | str = colors.neutral,
secondary_hue: colors.Color | str = colors.neutral,
neutral_hue: colors.Color | str = colors.neutral,
spacing_size: sizes.Size | str = sizes.spacing_lg,
spacing_size: sizes.Size | str = sizes.spacing_sm,
radius_size: sizes.Size | str = sizes.radius_none,
text_size: sizes.Size | str = sizes.text_md,
text_size: sizes.Size | str = sizes.text_sm,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
@@ -79,8 +79,8 @@ class John(Base):
button_cancel_background_fill_hover="*button_primary_background_fill_hover",
button_cancel_text_color="*button_primary_text_color",
# Padding
checkbox_label_padding="*spacing_md",
button_large_padding="*spacing_lg",
checkbox_label_padding="*spacing_sm",
button_large_padding="*spacing_sm",
button_small_padding="*spacing_sm",
# Borders
block_border_width="0px",
@@ -91,5 +91,5 @@ class John(Base):
# Block Labels
block_title_text_weight="600",
block_label_text_weight="600",
block_label_text_size="*text_md",
block_label_text_size="*text_sm",
)

View File

@@ -26,9 +26,9 @@ def build_from_dict(config: Union[str, dict]):
for key, value in config_dict.items():
pipeline_def = import_dotted_string(key, safe=False)
if value["ui-type"] == "chat":
demos.append(build_chat_ui(value, pipeline_def))
demos.append(build_chat_ui(value, pipeline_def).queue())
else:
demos.append(build_pipeline_ui(value, pipeline_def))
demos.append(build_pipeline_ui(value, pipeline_def).queue())
if len(demos) == 1:
demo = demos[0]
else:

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
from typing import Any, AsyncGenerator
import anyio
from gradio import ChatInterface
from gradio.components import IOComponent, get_component_instance
from gradio.events import on
from gradio.helpers import special_args
from gradio.routes import Request
class ChatBlock(ChatInterface):
    """The ChatBlock subclasses ChatInterface to provide extra functionalities:

    - Show additional outputs to the chat interface
    - Disallow blank user message
    """

    def __init__(
        self,
        *args,
        additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
        **kwargs,
    ):
        """Accept extra output components on top of the ChatInterface args.

        Args:
            additional_outputs: component(s) (or their string ids) whose values
                are produced by the chat fn alongside the text response and
                rendered next to the chat. A single item is normalized to a list.
        """
        if additional_outputs:
            if not isinstance(additional_outputs, list):
                additional_outputs = [additional_outputs]
            self.additional_outputs = [
                get_component_instance(i) for i in additional_outputs  # type: ignore
            ]
        else:
            self.additional_outputs = []

        super().__init__(*args, **kwargs)

    async def _submit_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        request: Request,
        *args,
    ) -> tuple[Any, ...]:
        """Run the chat fn and route extra inputs/outputs around it.

        ``*args`` carries the additional-input values followed by the current
        additional-output values (see the submit event wiring in
        ``_setup_events``).

        Returns:
            (chatbot value, chatbot state, *additional output values)
        """
        # FIX: the original used `args[: -len(self.additional_outputs)]` /
        # `args[-len(self.additional_outputs):]` unconditionally. With zero
        # additional outputs that is `args[:-0]` == `args[:0]` (empty) and
        # `args[-0:]` == all of args — every additional *input* was silently
        # dropped and misrouted as outputs. Guard the n == 0 case explicitly.
        n = len(self.additional_outputs)
        input_args = args[:-n] if n else args
        output_args = args[-n:] if n else ()

        # Blank message: keep history and current outputs untouched.
        if not message:
            return history_with_input, history_with_input, *output_args

        history = history_with_input[:-1]
        inputs, _, _ = special_args(
            self.fn, inputs=[message, history, *input_args], request=request
        )

        if self.is_async:
            response = await self.fn(*inputs)
        else:
            # Sync fn: run in a worker thread so the event loop stays free.
            response = await anyio.to_thread.run_sync(
                self.fn, *inputs, limiter=self.limiter
            )

        # With additional outputs the fn returns (text, *extras); otherwise
        # it returns just the text response.
        output = []
        if self.additional_outputs:
            text = response[0]
            output = response[1:]
        else:
            text = response

        history.append([message, text])
        return history, history, *output

    async def _stream_fn(
        self,
        message: str,
        history_with_input: list[list[str | None]],
        *args,
    ) -> AsyncGenerator:
        """Streaming is not supported by ChatBlock (additional outputs cannot
        be interleaved with a token stream)."""
        raise NotImplementedError("Stream function not implemented for ChatBlock")

    def _display_input(
        self, message: str, history: list[list[str | None]]
    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
        """Stop displaying the input message if the message is a blank string"""
        if not message:
            return history, history
        return super()._display_input(message, history)

    def _setup_events(self) -> None:
        """Include additional outputs in the submit event"""
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
        submit_triggers = (
            [self.textbox.submit, self.submit_btn.click]
            if self.submit_btn
            else [self.textbox.submit]
        )
        # Submit: clear textbox -> echo input -> run fn. The additional
        # outputs are passed both as inputs (so their current values survive
        # a blank message) and as outputs (so the fn can update them).
        submit_event = (
            on(
                submit_triggers,
                self._clear_and_save_textbox,
                [self.textbox],
                [self.textbox, self.saved_input],
                api_name=False,
                queue=False,
            )
            .then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                api_name=False,
                queue=False,
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state]
                + self.additional_inputs
                + self.additional_outputs,
                [self.chatbot, self.chatbot_state] + self.additional_outputs,
                api_name=False,
            )
        )
        self._setup_stop_events(submit_triggers, submit_event)

        if self.retry_btn:
            # Retry: drop the previous exchange, re-echo the saved input,
            # then re-run the fn with the same additional input/output wiring.
            retry_event = (
                self.retry_btn.click(
                    self._delete_prev_fn,
                    [self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    api_name=False,
                    queue=False,
                )
                .then(
                    self._display_input,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.chatbot_state],
                    api_name=False,
                    queue=False,
                )
                .then(
                    submit_fn,
                    [self.saved_input, self.chatbot_state]
                    + self.additional_inputs
                    + self.additional_outputs,
                    [self.chatbot, self.chatbot_state] + self.additional_outputs,
                    api_name=False,
                )
            )
            self._setup_stop_events([self.retry_btn.click], retry_event)

        if self.undo_btn:
            # Undo: remove the last exchange and restore the message to the
            # textbox so the user can edit and resubmit it.
            self.undo_btn.click(
                self._delete_prev_fn,
                [self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                api_name=False,
                queue=False,
            ).then(
                lambda x: x,
                [self.saved_input],
                [self.textbox],
                api_name=False,
                queue=False,
            )

        if self.clear_btn:
            # Clear: reset chatbot display, state, and the saved input.
            self.clear_btn.click(
                lambda: ([], [], None),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input],
                queue=False,
                api_name=False,
            )

    def _setup_api(self) -> None:
        """Expose the chat fn over the /chat API route, including the
        additional outputs in the response."""
        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn

        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state] + self.additional_outputs,
            api_name="chat",
        )

View File

@@ -8,6 +8,9 @@ from theflow.storage import storage
from kotaemon.chatbot import ChatConversation
from kotaemon.contribs.promptui.base import get_component
from kotaemon.contribs.promptui.export import export
from kotaemon.contribs.promptui.ui.blocks import ChatBlock
from ..logs import ResultLog
USAGE_INSTRUCTION = """## How to use:
@@ -87,8 +90,10 @@ def construct_chat_ui(
outputs.append(component)
sess = gr.State(value=None)
chatbot = gr.Chatbot(label="Chatbot")
chat = gr.ChatInterface(func_chat, chatbot=chatbot, additional_inputs=[sess])
chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)
chat = ChatBlock(
func_chat, chatbot=chatbot, additional_inputs=[sess], additional_outputs=outputs
)
param_state = gr.Textbox(interactive=False)
with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
@@ -106,6 +111,7 @@ def construct_chat_ui(
chat.saved_input,
param_state,
sess,
*outputs,
],
)
with gr.Accordion(label="End chat", open=False):
@@ -162,6 +168,9 @@ def build_chat_ui(config, pipeline_def):
exported_dir = output_dir.parent / "exported"
exported_dir.mkdir(parents=True, exist_ok=True)
resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
def new_chat(*args):
"""Start a new chat function
@@ -190,7 +199,14 @@ def build_chat_ui(config, pipeline_def):
)
gr.Info("New chat session started.")
return [], [], None, param_state_str, session
return (
[],
[],
None,
param_state_str,
session,
*[None] * len(config.get("outputs", [])),
)
def chat(message, history, session, *args):
"""The chat interface
@@ -212,7 +228,18 @@ def build_chat_ui(config, pipeline_def):
"No active chat session. Please set the params and click New chat"
)
return session(message).content
pred = session(message)
text_response = pred.content
additional_outputs = []
for output_def in config.get("outputs", []):
value = session.last_run.logs(output_def["step"])
getter = output_def.get("getter", None)
if getter and getter in allowed_resultlog_callbacks:
value = getattr(resultlog, getter)(value)
additional_outputs.append(value)
return text_response, *additional_outputs
def end_chat(preference: str, save_log: bool, session):
"""End the chat session