Upgrade the declarative pipeline for cleaner interface (#51)

This commit is contained in:
Nguyen Trung Duc (john)
2023-10-24 11:12:22 +07:00
committed by GitHub
parent aab982ddc4
commit 9035e25666
26 changed files with 365 additions and 169 deletions

View File

@@ -26,9 +26,9 @@ def build_from_dict(config: Union[str, dict]):
for key, value in config_dict.items():
pipeline_def = import_dotted_string(key, safe=False)
if value["ui-type"] == "chat":
demos.append(build_chat_ui(value, pipeline_def))
demos.append(build_chat_ui(value, pipeline_def).queue())
else:
demos.append(build_pipeline_ui(value, pipeline_def))
demos.append(build_pipeline_ui(value, pipeline_def).queue())
if len(demos) == 1:
demo = demos[0]
else:

View File

@@ -0,0 +1,181 @@
from __future__ import annotations
from typing import Any, AsyncGenerator
import anyio
from gradio import ChatInterface
from gradio.components import IOComponent, get_component_instance
from gradio.events import on
from gradio.helpers import special_args
from gradio.routes import Request
class ChatBlock(ChatInterface):
"""The ChatBlock subclasses ChatInterface to provide extra functionalities:
- Show additional outputs to the chat interface
- Disallow blank user message
"""
def __init__(
self,
*args,
additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
**kwargs,
):
if additional_outputs:
if not isinstance(additional_outputs, list):
additional_outputs = [additional_outputs]
self.additional_outputs = [
get_component_instance(i) for i in additional_outputs # type: ignore
]
else:
self.additional_outputs = []
super().__init__(*args, **kwargs)
async def _submit_fn(
self,
message: str,
history_with_input: list[list[str | None]],
request: Request,
*args,
) -> tuple[Any, ...]:
input_args = args[: -len(self.additional_outputs)]
output_args = args[-len(self.additional_outputs) :]
if not message:
return history_with_input, history_with_input, *output_args
history = history_with_input[:-1]
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *input_args], request=request
)
if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
output = []
if self.additional_outputs:
text = response[0]
output = response[1:]
else:
text = response
history.append([message, text])
return history, history, *output
async def _stream_fn(
self,
message: str,
history_with_input: list[list[str | None]],
*args,
) -> AsyncGenerator:
raise NotImplementedError("Stream function not implemented for ChatBlock")
def _display_input(
self, message: str, history: list[list[str | None]]
) -> tuple[list[list[str | None]], list[list[str | None]]]:
"""Stop displaying the input message if the message is a blank string"""
if not message:
return history, history
return super()._display_input(message, history)
def _setup_events(self) -> None:
"""Include additional outputs in the submit event"""
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
submit_triggers = (
[self.textbox.submit, self.submit_btn.click]
if self.submit_btn
else [self.textbox.submit]
)
submit_event = (
on(
submit_triggers,
self._clear_and_save_textbox,
[self.textbox],
[self.textbox, self.saved_input],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events(submit_triggers, submit_event)
if self.retry_btn:
retry_event = (
self.retry_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events([self.retry_btn.click], retry_event)
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
).then(
lambda x: x,
[self.saved_input],
[self.textbox],
api_name=False,
queue=False,
)
if self.clear_btn:
self.clear_btn.click(
lambda: ([], [], None),
None,
[self.chatbot, self.chatbot_state, self.saved_input],
queue=False,
api_name=False,
)
def _setup_api(self) -> None:
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
self.fake_api_btn.click(
api_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state] + self.additional_outputs,
api_name="chat",
)

View File

@@ -8,6 +8,9 @@ from theflow.storage import storage
from kotaemon.chatbot import ChatConversation
from kotaemon.contribs.promptui.base import get_component
from kotaemon.contribs.promptui.export import export
from kotaemon.contribs.promptui.ui.blocks import ChatBlock
from ..logs import ResultLog
USAGE_INSTRUCTION = """## How to use:
@@ -87,8 +90,10 @@ def construct_chat_ui(
outputs.append(component)
sess = gr.State(value=None)
chatbot = gr.Chatbot(label="Chatbot")
chat = gr.ChatInterface(func_chat, chatbot=chatbot, additional_inputs=[sess])
chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)
chat = ChatBlock(
func_chat, chatbot=chatbot, additional_inputs=[sess], additional_outputs=outputs
)
param_state = gr.Textbox(interactive=False)
with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
@@ -106,6 +111,7 @@ def construct_chat_ui(
chat.saved_input,
param_state,
sess,
*outputs,
],
)
with gr.Accordion(label="End chat", open=False):
@@ -162,6 +168,9 @@ def build_chat_ui(config, pipeline_def):
exported_dir = output_dir.parent / "exported"
exported_dir.mkdir(parents=True, exist_ok=True)
resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
def new_chat(*args):
"""Start a new chat function
@@ -190,7 +199,14 @@ def build_chat_ui(config, pipeline_def):
)
gr.Info("New chat session started.")
return [], [], None, param_state_str, session
return (
[],
[],
None,
param_state_str,
session,
*[None] * len(config.get("outputs", [])),
)
def chat(message, history, session, *args):
"""The chat interface
@@ -212,7 +228,18 @@ def build_chat_ui(config, pipeline_def):
"No active chat session. Please set the params and click New chat"
)
return session(message).content
pred = session(message)
text_response = pred.content
additional_outputs = []
for output_def in config.get("outputs", []):
value = session.last_run.logs(output_def["step"])
getter = output_def.get("getter", None)
if getter and getter in allowed_resultlog_callbacks:
value = getattr(resultlog, getter)(value)
additional_outputs.append(value)
return text_response, *additional_outputs
def end_chat(preference: str, save_log: bool, session):
"""End the chat session