kotaemon/libs/ktem/ktem/rerankings/ui.py
KennyWu 53530e296f
feat: support TEI embedding service, configurable reranking model (#287)
* feat: add support for TEI embedding service, allow reranking model to be configurable.

Signed-off-by: Kennywu <jdlow@live.cn>

* fix: add cohere default reranking model

* fix: comfort pre-commit

---------

Signed-off-by: Kennywu <jdlow@live.cn>
Co-authored-by: wujiaye <wujiaye@bluemoon.com.cn>
Co-authored-by: Tadashi <tadashi@cinnamon.is>
2024-09-30 22:00:00 +07:00

391 lines
15 KiB
Python

from copy import deepcopy
import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage
from ktem.utils.file import YAMLNoDateSafeLoader
from theflow.utils.modules import deserialize
from .manager import reranking_models_manager
def format_description(cls):
params = cls.describe()["params"]
params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
for key, value in params.items():
if isinstance(value["auto_callback"], str):
continue
params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
class RerankingManagement(BasePage):
def __init__(self, app):
self._app = app
self.spec_desc_default = (
"# Spec description\n\nSelect a model to view the spec description."
)
self.on_building_ui()
def on_building_ui(self):
with gr.Tab(label="View"):
self.rerank_list = gr.DataFrame(
headers=["name", "vendor", "default"],
interactive=False,
)
with gr.Column(visible=False) as self._selected_panel:
self.selected_rerank_name = gr.Textbox(value="", visible=False)
with gr.Row():
with gr.Column():
self.edit_default = gr.Checkbox(
label="Set default",
info=(
"Set this Reranking model as default. This default "
"Reranking will be used by other components by default "
"if no Reranking is specified for such components."
),
)
self.edit_spec = gr.Textbox(
label="Specification",
info="Specification of the Embedding model in YAML format",
lines=10,
)
with gr.Accordion(
label="Test connection", visible=False, open=False
) as self._check_connection_panel:
with gr.Row():
with gr.Column(scale=4):
self.connection_logs = gr.HTML(
"Logs",
)
with gr.Column(scale=1):
self.btn_test_connection = gr.Button("Test")
with gr.Row(visible=False) as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button(
"Save", min_width=10, variant="primary"
)
with gr.Column():
self.btn_delete = gr.Button(
"Delete", min_width=10, variant="stop"
)
with gr.Row():
self.btn_delete_yes = gr.Button(
"Confirm Delete",
variant="stop",
visible=False,
min_width=10,
)
self.btn_delete_no = gr.Button(
"Cancel", visible=False, min_width=10
)
with gr.Column():
self.btn_close = gr.Button("Close", min_width=10)
with gr.Column():
self.edit_spec_desc = gr.Markdown("# Spec description")
with gr.Tab(label="Add"):
with gr.Row():
with gr.Column(scale=2):
self.name = gr.Textbox(
label="Name",
info=(
"Must be unique and non-empty. "
"The name will be used to identify the reranking model."
),
)
self.rerank_choices = gr.Dropdown(
label="Vendors",
info=(
"Choose the vendor of the Reranking model. Each vendor "
"has different specification."
),
)
self.spec = gr.Textbox(
label="Specification",
info="Specification of the Embedding model in YAML format.",
)
self.default = gr.Checkbox(
label="Set default",
info=(
"Set this Reranking model as default. This default "
"Reranking will be used by other components by default "
"if no Reranking is specified for such components."
),
)
self.btn_new = gr.Button("Add", variant="primary")
with gr.Column(scale=3):
self.spec_desc = gr.Markdown(self.spec_desc_default)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
self.list_rerankings,
inputs=[],
outputs=[self.rerank_list],
)
self._app.app.load(
lambda: gr.update(choices=list(reranking_models_manager.vendors().keys())),
outputs=[self.rerank_choices],
)
def on_rerank_vendor_change(self, vendor):
vendor = reranking_models_manager.vendors()[vendor]
required: dict = {}
desc = vendor.describe()
for key, value in desc["params"].items():
if value.get("required", False):
required[key] = value.get("default", None)
return yaml.dump(required), format_description(vendor)
def on_register_events(self):
self.rerank_choices.select(
self.on_rerank_vendor_change,
inputs=[self.rerank_choices],
outputs=[self.spec, self.spec_desc],
)
self.btn_new.click(
self.create_rerank,
inputs=[self.name, self.rerank_choices, self.spec, self.default],
outputs=None,
).success(self.list_rerankings, inputs=[], outputs=[self.rerank_list]).success(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
self.rerank_choices,
self.spec,
self.default,
self.spec_desc,
],
)
self.rerank_list.select(
self.select_rerank,
inputs=self.rerank_list,
outputs=[self.selected_rerank_name],
show_progress="hidden",
)
self.selected_rerank_name.change(
self.on_selected_rerank_change,
inputs=[self.selected_rerank_name],
outputs=[
self._selected_panel,
self._selected_panel_btn,
# delete section
self.btn_delete,
self.btn_delete_yes,
self.btn_delete_no,
# edit section
self.edit_spec,
self.edit_spec_desc,
self.edit_default,
self._check_connection_panel,
],
show_progress="hidden",
).success(lambda: gr.update(value=""), outputs=[self.connection_logs])
self.btn_delete.click(
self.on_btn_delete_click,
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_delete_yes.click(
self.delete_rerank,
inputs=[self.selected_rerank_name],
outputs=[self.selected_rerank_name],
show_progress="hidden",
).then(
self.list_rerankings,
inputs=[],
outputs=[self.rerank_list],
)
self.btn_delete_no.click(
lambda: (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
),
inputs=[],
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_edit_save.click(
self.save_rerank,
inputs=[
self.selected_rerank_name,
self.edit_default,
self.edit_spec,
],
show_progress="hidden",
).then(
self.list_rerankings,
inputs=[],
outputs=[self.rerank_list],
)
self.btn_close.click(lambda: "", outputs=[self.selected_rerank_name])
self.btn_test_connection.click(
self.check_connection,
inputs=[self.selected_rerank_name, self.edit_spec],
outputs=[self.connection_logs],
)
def create_rerank(self, name, choices, spec, default):
try:
spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
spec["__type__"] = (
reranking_models_manager.vendors()[choices].__module__
+ "."
+ reranking_models_manager.vendors()[choices].__qualname__
)
reranking_models_manager.add(name, spec=spec, default=default)
gr.Info(f'Create Reranking model "{name}" successfully')
except Exception as e:
raise gr.Error(f"Failed to create Reranking model {name}: {e}")
def list_rerankings(self):
"""List the Reranking models"""
items = []
for item in reranking_models_manager.info().values():
record = {}
record["name"] = item["name"]
record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
record["default"] = item["default"]
items.append(record)
if items:
rerank_list = pd.DataFrame.from_records(items)
else:
rerank_list = pd.DataFrame.from_records(
[{"name": "-", "vendor": "-", "default": "-"}]
)
return rerank_list
def select_rerank(self, rerank_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No reranking model is loaded. Please add first")
return ""
if not ev.selected:
return ""
return rerank_list["name"][ev.index[0]]
def on_selected_rerank_change(self, selected_rerank_name):
if selected_rerank_name == "":
_check_connection_panel = gr.update(visible=False)
_selected_panel = gr.update(visible=False)
_selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
edit_spec = gr.update(value="")
edit_spec_desc = gr.update(value="")
edit_default = gr.update(value=False)
else:
_check_connection_panel = gr.update(visible=True)
_selected_panel = gr.update(visible=True)
_selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
vendor = reranking_models_manager.vendors()[vendor_str]
edit_spec = yaml.dump(info["spec"])
edit_spec_desc = format_description(vendor)
edit_default = info["default"]
return (
_selected_panel,
_selected_panel_btn,
btn_delete,
btn_delete_yes,
btn_delete_no,
edit_spec,
edit_spec_desc,
edit_default,
_check_connection_panel,
)
def on_btn_delete_click(self):
btn_delete = gr.update(visible=False)
btn_delete_yes = gr.update(visible=True)
btn_delete_no = gr.update(visible=True)
return btn_delete, btn_delete_yes, btn_delete_no
def check_connection(self, selected_rerank_name, selected_spec):
log_content: str = ""
try:
log_content += f"- Testing model: {selected_rerank_name}<br>"
yield log_content
# Parse content & init model
info = deepcopy(reranking_models_manager.info()[selected_rerank_name])
# Parse content & create dummy response
spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
info["spec"].update(spec)
rerank = deserialize(info["spec"], safe=False)
if rerank is None:
raise Exception(f"Can not found model: {selected_rerank_name}")
log_content += "- Sending a message ([`Hello`], `Hi`)<br>"
yield log_content
_ = rerank(["Hello"], "Hi")
log_content += (
"<mark style='background: green; color: white'>- Connection success. "
"</mark><br>"
)
yield log_content
gr.Info(f"Embedding {selected_rerank_name} connect successfully")
except Exception as e:
print(e)
log_content += (
f"<mark style='color: yellow; background: red'>- Connection failed. "
f"Got error:\n {str(e)}</mark>"
)
yield log_content
return log_content
def save_rerank(self, selected_rerank_name, default, spec):
try:
spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
spec["__type__"] = reranking_models_manager.info()[selected_rerank_name][
"spec"
]["__type__"]
reranking_models_manager.update(
selected_rerank_name, spec=spec, default=default
)
gr.Info(f'Save Reranking model "{selected_rerank_name}" successfully')
except Exception as e:
gr.Error(f'Failed to save Embedding model "{selected_rerank_name}": {e}')
def delete_rerank(self, selected_rerank_name):
try:
reranking_models_manager.delete(selected_rerank_name)
except Exception as e:
gr.Error(f'Failed to delete Reranking model "{selected_rerank_name}": {e}')
return selected_rerank_name
return ""