Allow users to add LLM within the UI (#6)
* Rename AzureChatOpenAI to LCAzureChatOpenAI
* Provide vanilla ChatOpenAI and AzureChatOpenAI
* Remove the highest-accuracy / lowest-cost criteria. These criteria are unnecessary: the users, not the pipeline creators, should choose which LLM to use. Moreover, supplying this information is cumbersome and really degrades the user experience.
* Remove the LLM selection in the simple reasoning pipeline
* Provide a dedicated stream method to generate the output
* Return a placeholder message to the chat if the text is empty
committed by GitHub
parent e187e23dd1
commit a203fc0f7c
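For reference, the old `def` entries with accuracy/cost scores in `flowsettings.py` are replaced by a plain `spec` plus a `default` flag. A minimal sketch of the new entry format, with placeholder endpoint and deployment values (the key names are taken from the flowsettings hunk below):

```python
# Sketch of the new spec-based KH_LLMS entry; values are placeholders.
KH_LLMS["azure"] = {
    "spec": {
        "__type__": "kotaemon.llms.AzureChatOpenAI",
        "temperature": 0,
        "azure_endpoint": "https://<your-resource>.openai.azure.com/",
        "api_key": "<AZURE_OPENAI_API_KEY>",
        "api_version": "2024-02-15-preview",
        "azure_deployment": "<AZURE_OPENAI_CHAT_DEPLOYMENT>",
        "timeout": 20,
    },
    "default": False,
}
```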
@@ -40,16 +40,15 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
):
    if config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""):
        KH_LLMS["azure"] = {
            "def": {
            "spec": {
                "__type__": "kotaemon.llms.AzureChatOpenAI",
                "temperature": 0,
                "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
                "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
                "api_key": config("AZURE_OPENAI_API_KEY", default=""),
                "api_version": config("OPENAI_API_VERSION", default="")
                or "2024-02-15-preview",
                "deployment_name": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
                "request_timeout": 10,
                "stream": False,
                "azure_deployment": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
                "timeout": 20,
            },
            "default": False,
            "accuracy": 5,
@@ -57,7 +56,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
        }
    if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
        KH_EMBEDDINGS["azure"] = {
            "def": {
            "spec": {
                "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
                "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
                "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
@@ -164,5 +163,11 @@ KH_INDICES = [
        "name": "File",
        "config": {},
        "index_type": "ktem.index.file.FileIndex",
    }
    },
    {
        "id": 2,
        "name": "Sample",
        "config": {},
        "index_type": "ktem.index.file.FileIndex",
    },
]
@@ -3,6 +3,7 @@
import logging
from functools import cache
from pathlib import Path
from typing import Optional

from theflow.settings import settings
from theflow.utils.modules import deserialize
@@ -48,7 +49,7 @@ class ModelPool:
        self._default: list[str] = []

        for name, model in conf.items():
            self._models[name] = deserialize(model["def"], safe=False)
            self._models[name] = deserialize(model["spec"], safe=False)
            if model.get("default", False):
                self._default.append(name)

@@ -58,11 +59,27 @@ class ModelPool:
        self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf"))))

    def __getitem__(self, key: str) -> BaseComponent:
        """Get model by name"""
        return self._models[key]

    def __setitem__(self, key: str, value: BaseComponent):
        """Set model by name"""
        self._models[key] = value

    def __delitem__(self, key: str):
        """Delete model by name"""
        del self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    def get(
        self, key: str, default: Optional[BaseComponent] = None
    ) -> Optional[BaseComponent]:
        """Get model by name with default value"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present model pools option for gradio"""
        return {
@@ -169,4 +186,3 @@ llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
indices = ModelPool("Indices", {})
@@ -157,10 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):

    @classmethod
    def get_user_settings(cls) -> dict:
        from ktem.components import llms
        from ktem.llms.manager import llms

        try:
            reranking_llm = llms.get_lowest_cost_name()
            reranking_llm = llms.get_default_name()
            reranking_llm_choices = list(llms.options().keys())
        except Exception as e:
            logger.error(e)
libs/ktem/ktem/llms/__init__.py (new file, 0 lines)
libs/ktem/ktem/llms/db.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from typing import Type

from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string


class Base(DeclarativeBase):
    pass


class BaseLLMTable(Base):
    """Base table to store language model"""

    __abstract__ = True

    name = Column(String, primary_key=True, unique=True)
    spec = Column(JSON, default={})
    default = Column(Boolean, default=False)


_base_llm: Type[BaseLLMTable] = (
    import_dotted_string(flowsettings.KH_TABLE_LLM, safe=False)
    if hasattr(flowsettings, "KH_TABLE_LLM")
    else BaseLLMTable
)


class LLMTable(_base_llm):  # type: ignore
    __tablename__ = "llm_table"


if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
    LLMTable.metadata.create_all(engine)
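The `KH_TABLE_LLM` hook above lets a deployment swap in its own base table before `LLMTable` is defined. A minimal sketch of such an override, assuming a hypothetical `myapp.models` module and the matching `flowsettings` entry:

```python
# myapp/models.py (hypothetical module): extend the abstract base table
from ktem.llms.db import BaseLLMTable
from sqlalchemy import Column, String


class TenantLLMTable(BaseLLMTable):
    """Same columns as BaseLLMTable plus a tenant marker (illustrative)."""

    __abstract__ = True

    tenant = Column(String, default="")


# flowsettings.py (hypothetical entry): point ktem at the custom base table
KH_TABLE_LLM = "myapp.models.TenantLLMTable"
```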
libs/ktem/ktem/llms/manager.py (new file, 191 lines)
@@ -0,0 +1,191 @@
from typing import Optional, Type

from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize

from kotaemon.base import BaseComponent

from .db import LLMTable, engine


class LLMManager:
    """Represent a pool of models"""

    def __init__(self):
        self._models: dict[str, BaseComponent] = {}
        self._info: dict[str, dict] = {}
        self._default: str = ""
        self._vendors: list[Type] = []

        if hasattr(flowsettings, "KH_LLMS"):
            for name, model in flowsettings.KH_LLMS.items():
                with Session(engine) as session:
                    stmt = select(LLMTable).where(LLMTable.name == name)
                    result = session.execute(stmt)
                    if not result.first():
                        item = LLMTable(
                            name=name,
                            spec=model["spec"],
                            default=model.get("default", False),
                        )
                        session.add(item)
                        session.commit()

        self.load()
        self.load_vendors()

    def load(self):
        """Load the model pool from database"""
        self._models, self._info, self._defaut = {}, {}, ""
        with Session(engine) as session:
            stmt = select(LLMTable)
            items = session.execute(stmt)

            for (item,) in items:
                self._models[item.name] = deserialize(item.spec, safe=False)
                self._info[item.name] = {
                    "name": item.name,
                    "spec": item.spec,
                    "default": item.default,
                }
                if item.default:
                    self._default = item.name

    def load_vendors(self):
        from kotaemon.llms import (
            AzureChatOpenAI,
            ChatOpenAI,
            EndpointChatLLM,
            LlamaCppChat,
        )

        self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]

    def __getitem__(self, key: str) -> BaseComponent:
        """Get model by name"""
        return self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    def get(
        self, key: str, default: Optional[BaseComponent] = None
    ) -> Optional[BaseComponent]:
        """Get model by name with default value"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present model pools option for gradio"""
        return {
            "label": "LLM",
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a dict of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of random model

        Returns:
            str: random model name in the pool
        """
        import random

        if not self._models:
            raise ValueError("No models in pool")

        return random.choice(list(self._models.keys()))

    def get_default_name(self) -> str:
        """Get the name of default model

        In case there is no default model, choose random model from pool. In
        case there are multiple default models, choose random from them.

        Returns:
            str: model name
        """
        if not self._models:
            raise ValueError("No models in pool")

        if not self._default:
            return self.get_random_name()

        return self._default

    def get_random(self) -> BaseComponent:
        """Get random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> BaseComponent:
        """Get default model

        In case there is no default model, choose random model from pool. In
        case there are multiple default models, choose random from them.

        Returns:
            BaseComponent: model
        """
        return self._models[self.get_default_name()]

    def info(self) -> dict:
        """List all models"""
        return self._info

    def add(self, name: str, spec: dict, default: bool):
        """Add a new model to the pool"""
        try:
            with Session(engine) as session:
                item = LLMTable(name=name, spec=spec, default=default)
                session.add(item)
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to add model {name}: {e}")

        self.load()

    def delete(self, name: str):
        """Delete a model from the pool"""
        try:
            with Session(engine) as session:
                item = session.query(LLMTable).filter_by(name=name).first()
                session.delete(item)
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to delete model {name}: {e}")

        self.load()

    def update(self, name: str, spec: dict, default: bool):
        """Update a model in the pool"""
        try:
            with Session(engine) as session:

                if default:
                    # turn all models to non-default
                    session.query(LLMTable).update({"default": False})
                    session.commit()

                item = session.query(LLMTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")
                item.spec = spec
                item.default = default
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to update model {name}: {e}")

        self.load()

    def vendors(self) -> dict:
        """Return list of vendors"""
        return {vendor.__qualname__: vendor for vendor in self._vendors}


llms = LLMManager()
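The manager above is what the new admin UI talks to. A short usage sketch, assuming the `llms` singleton defined at the end of `manager.py`; the model name and the spec keys besides `__type__` are illustrative:

```python
from ktem.llms.manager import llms

# Register a model and mark it as the application-wide default
llms.add(
    "openai-chat",
    spec={
        "__type__": "kotaemon.llms.ChatOpenAI",
        "api_key": "<OPENAI_API_KEY>",
        "temperature": 0,
    },
    default=True,
)

print(llms.get_default_name())  # -> "openai-chat"
model = llms.get_default()      # BaseComponent deserialized from the stored spec
llms.delete("openai-chat")      # remove the row and reload the pool
```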
libs/ktem/ktem/llms/ui.py (new file, 318 lines)
@@ -0,0 +1,318 @@
from copy import deepcopy

import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage

from .manager import llms


def format_description(cls):
    params = cls.describe()["params"]
    params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
    for key, value in params.items():
        if isinstance(value["auto_callback"], str):
            continue
        params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
    return f"{cls.__doc__}\n\n" + "\n".join(params_lines)


class LLMManagement(BasePage):
    def __init__(self, app):
        self._app = app
        self.spec_desc_default = (
            "# Spec description\n\nSelect an LLM to view the spec description."
        )
        self.on_building_ui()

    def on_building_ui(self):
        with gr.Tab(label="View"):
            self.llm_list = gr.DataFrame(
                headers=["name", "vendor", "default"],
                interactive=False,
            )

            with gr.Column(visible=False) as self._selected_panel:
                self.selected_llm_name = gr.Textbox(value="", visible=False)
                with gr.Row():
                    with gr.Column():
                        self.edit_default = gr.Checkbox(
                            label="Set default",
                            info=(
                                "Set this LLM as default. If no default is set, a "
                                "random LLM will be used."
                            ),
                        )
                        self.edit_spec = gr.Textbox(
                            label="Specification",
                            info="Specification of the LLM in YAML format",
                            lines=10,
                        )

                        with gr.Row(visible=False) as self._selected_panel_btn:
                            with gr.Column():
                                self.btn_edit_save = gr.Button("Save", min_width=10)
                            with gr.Column():
                                self.btn_delete = gr.Button("Delete", min_width=10)
                                with gr.Row():
                                    self.btn_delete_yes = gr.Button(
                                        "Confirm delete",
                                        variant="primary",
                                        visible=False,
                                        min_width=10,
                                    )
                                    self.btn_delete_no = gr.Button(
                                        "Cancel", visible=False, min_width=10
                                    )
                            with gr.Column():
                                self.btn_close = gr.Button("Close", min_width=10)

                    with gr.Column():
                        self.edit_spec_desc = gr.Markdown("# Spec description")

        with gr.Tab(label="Add"):
            with gr.Row():
                with gr.Column(scale=2):
                    self.name = gr.Textbox(
                        label="LLM name",
                        info=(
                            "Must be unique. The name will be used to identify the LLM."
                        ),
                    )
                    self.llm_choices = gr.Dropdown(
                        label="LLM vendors",
                        info=(
                            "Choose the vendor for the LLM. Each vendor has different "
                            "specification."
                        ),
                    )
                    self.spec = gr.Textbox(
                        label="Specification",
                        info="Specification of the LLM in YAML format",
                    )
                    self.default = gr.Checkbox(
                        label="Set default",
                        info=(
                            "Set this LLM as default. This default LLM will be used "
                            "by default across the application."
                        ),
                    )
                    self.btn_new = gr.Button("Create LLM")

                with gr.Column(scale=3):
                    self.spec_desc = gr.Markdown(self.spec_desc_default)

    def _on_app_created(self):
        """Called when the app is created"""
        self._app.app.load(
            self.list_llms,
            inputs=None,
            outputs=[self.llm_list],
        )
        self._app.app.load(
            lambda: gr.update(choices=list(llms.vendors().keys())),
            outputs=[self.llm_choices],
        )

    def on_llm_vendor_change(self, vendor):
        vendor = llms.vendors()[vendor]

        required: dict = {}
        desc = vendor.describe()
        for key, value in desc["params"].items():
            if value.get("required", False):
                required[key] = None

        return yaml.dump(required), format_description(vendor)

    def on_register_events(self):
        self.llm_choices.select(
            self.on_llm_vendor_change,
            inputs=[self.llm_choices],
            outputs=[self.spec, self.spec_desc],
        )
        self.btn_new.click(
            self.create_llm,
            inputs=[self.name, self.llm_choices, self.spec, self.default],
            outputs=None,
        ).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then(
            lambda: ("", None, "", False, self.spec_desc_default),
            outputs=[
                self.name,
                self.llm_choices,
                self.spec,
                self.default,
                self.spec_desc,
            ],
        )
        self.llm_list.select(
            self.select_llm,
            inputs=self.llm_list,
            outputs=[self.selected_llm_name],
            show_progress="hidden",
        )
        self.selected_llm_name.change(
            self.on_selected_llm_change,
            inputs=[self.selected_llm_name],
            outputs=[
                self._selected_panel,
                self._selected_panel_btn,
                # delete section
                self.btn_delete,
                self.btn_delete_yes,
                self.btn_delete_no,
                # edit section
                self.edit_spec,
                self.edit_spec_desc,
                self.edit_default,
            ],
            show_progress="hidden",
        )
        self.btn_delete.click(
            self.on_btn_delete_click,
            inputs=None,
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_delete_yes.click(
            self.delete_llm,
            inputs=[self.selected_llm_name],
            outputs=[self.selected_llm_name],
            show_progress="hidden",
        ).then(
            self.list_llms,
            inputs=None,
            outputs=[self.llm_list],
        )
        self.btn_delete_no.click(
            lambda: (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
            ),
            inputs=None,
            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
            show_progress="hidden",
        )
        self.btn_edit_save.click(
            self.save_llm,
            inputs=[
                self.selected_llm_name,
                self.edit_default,
                self.edit_spec,
            ],
            show_progress="hidden",
        ).then(
            self.list_llms,
            inputs=None,
            outputs=[self.llm_list],
        )
        self.btn_close.click(
            lambda: "",
            outputs=[self.selected_llm_name],
        )

    def create_llm(self, name, choices, spec, default):
        try:
            spec = yaml.safe_load(spec)
            spec["__type__"] = (
                llms.vendors()[choices].__module__
                + "."
                + llms.vendors()[choices].__qualname__
            )

            llms.add(name, spec=spec, default=default)
            gr.Info(f"LLM {name} created successfully")
        except Exception as e:
            gr.Error(f"Failed to create LLM {name}: {e}")

    def list_llms(self):
        """List the LLMs"""
        items = []
        for item in llms.info().values():
            record = {}
            record["name"] = item["name"]
            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
            record["default"] = item["default"]
            items.append(record)

        if items:
            llm_list = pd.DataFrame.from_records(items)
        else:
            llm_list = pd.DataFrame.from_records(
                [{"name": "-", "vendor": "-", "default": "-"}]
            )

        return llm_list

    def select_llm(self, llm_list, ev: gr.SelectData):
        if ev.value == "-" and ev.index[0] == 0:
            gr.Info("No LLM is loaded. Please add LLM first")
            return ""

        if not ev.selected:
            return ""

        return llm_list["name"][ev.index[0]]

    def on_selected_llm_change(self, selected_llm_name):
        if selected_llm_name == "":
            _selected_panel = gr.update(visible=False)
            _selected_panel_btn = gr.update(visible=False)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)
            edit_spec = gr.update(value="")
            edit_spec_desc = gr.update(value="")
            edit_default = gr.update(value=False)
        else:
            _selected_panel = gr.update(visible=True)
            _selected_panel_btn = gr.update(visible=True)
            btn_delete = gr.update(visible=True)
            btn_delete_yes = gr.update(visible=False)
            btn_delete_no = gr.update(visible=False)

            info = deepcopy(llms.info()[selected_llm_name])
            vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
            vendor = llms.vendors()[vendor_str]

            edit_spec = yaml.dump(info["spec"])
            edit_spec_desc = format_description(vendor)
            edit_default = info["default"]

        return (
            _selected_panel,
            _selected_panel_btn,
            btn_delete,
            btn_delete_yes,
            btn_delete_no,
            edit_spec,
            edit_spec_desc,
            edit_default,
        )

    def on_btn_delete_click(self):
        btn_delete = gr.update(visible=False)
        btn_delete_yes = gr.update(visible=True)
        btn_delete_no = gr.update(visible=True)

        return btn_delete, btn_delete_yes, btn_delete_no

    def save_llm(self, selected_llm_name, default, spec):
        try:
            spec = yaml.safe_load(spec)
            spec["__type__"] = llms.info()[selected_llm_name]["spec"]["__type__"]
            llms.update(selected_llm_name, spec=spec, default=default)
            gr.Info(f"LLM {selected_llm_name} saved successfully")
        except Exception as e:
            gr.Error(f"Failed to save LLM {selected_llm_name}: {e}")

    def delete_llm(self, selected_llm_name):
        try:
            llms.delete(selected_llm_name)
        except Exception as e:
            gr.Error(f"Failed to delete LLM {selected_llm_name}: {e}")
            return selected_llm_name

        return ""
@@ -1,6 +1,7 @@
import gradio as gr
from ktem.app import BasePage
from ktem.db.models import User, engine
from ktem.llms.ui import LLMManagement
from sqlmodel import Session, select

from .user import UserManagement
@@ -16,6 +17,9 @@ class AdminPage(BasePage):
        with gr.Tab("User Management", visible=False) as self.user_management_tab:
            self.user_management = UserManagement(self._app)

        with gr.Tab("LLM Management") as self.llm_management_tab:
            self.llm_management = LLMManagement(self._app)

    def on_subscribe_public_events(self):
        if self._app.f_user_management:
            self._app.subscribe_event(
@@ -9,6 +9,8 @@ from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings

from kotaemon.base import Document

from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE
@@ -189,6 +191,7 @@ class ChatPage(BasePage):
                self.chat_control.conversation_rn,
                self.chat_panel.chatbot,
                self.info_panel,
                self.chat_state,
            ]
            + self._indices_input,
            show_progress="hidden",
@@ -220,6 +223,7 @@ class ChatPage(BasePage):
                self.chat_control.conversation_rn,
                self.chat_panel.chatbot,
                self.info_panel,
                self.chat_state,
            ]
            + self._indices_input,
            show_progress="hidden",
@@ -392,7 +396,7 @@ class ChatPage(BasePage):

        return pipeline, reasoning_state

    async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
    def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
        """Chat function"""
        chat_input = chat_history[-1][0]
        chat_history = chat_history[:-1]
@@ -403,52 +407,43 @@ class ChatPage(BasePage):
        pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
        pipeline.set_output_queue(queue)

        asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
        text, refs = "", ""

        len_ref = -1  # for logging purpose
        msg_placeholder = getattr(
            flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
        )

        print(msg_placeholder)
        while True:
            try:
                response = queue.get_nowait()
            except Exception:
                state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
                yield chat_history + [
                    (chat_input, text or msg_placeholder)
                ], refs, state
        yield chat_history + [(chat_input, text or msg_placeholder)], refs, state

        len_ref = -1  # for logging purpose

        for response in pipeline.stream(chat_input, conversation_id, chat_history):

            if not isinstance(response, Document):
                continue

            if response is None:
                queue.task_done()
                print("Chat completed")
                break
            if response.channel is None:
                continue

            if "output" in response:
                if response["output"] is None:
            if response.channel == "chat":
                if response.content is None:
                    text = ""
                else:
                    text += response["output"]
                    text += response.content

            if "evidence" in response:
                if response["evidence"] is None:
            if response.channel == "info":
                if response.content is None:
                    refs = ""
                else:
                    refs += response["evidence"]
                    refs += response.content

            if len(refs) > len_ref:
                print(f"Len refs: {len(refs)}")
                len_ref = len(refs)

            state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
            yield chat_history + [(chat_input, text)], refs, state
        state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
        yield chat_history + [(chat_input, text or msg_placeholder)], refs, state

    async def regen_fn(
        self, conversation_id, chat_history, settings, state, *selecteds
    ):
    def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
        """Regen function"""
        if not chat_history:
            gr.Warning("Empty chat")
@@ -456,12 +451,11 @@ class ChatPage(BasePage):
            return

        state["app"]["regen"] = True
        async for chat, refs, state in self.chat_fn(
        for chat, refs, state in self.chat_fn(
            conversation_id, chat_history, settings, state, *selecteds
        ):
            new_state = deepcopy(state)
            new_state["app"]["regen"] = False
            yield chat, refs, new_state
        else:
            state["app"]["regen"] = False
            yield chat_history, "", state

        state["app"]["regen"] = False
@@ -4,10 +4,10 @@ import logging
import re
from collections import defaultdict
from functools import partial
from typing import Generator

import tiktoken
from ktem.components import llms
from theflow.settings import settings as flowsettings
from ktem.llms.manager import llms

from kotaemon.base import (
    BaseComponent,
@@ -190,10 +190,10 @@ class AnswerWithContextPipeline(BaseComponent):
        lang: the language of the answer. Currently support English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
    vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    vlm_endpoint: str = ""
    citation_pipeline: CitationPipeline = Node(
        default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
        default_callback=lambda _: CitationPipeline(llm=llms.get_default())
    )

    qa_template: str = DEFAULT_QA_TEXT_PROMPT
@@ -297,13 +297,95 @@ class AnswerWithContextPipeline(BaseComponent):

        return answer

    def stream(  # type: ignore
        self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
    ) -> Generator[Document, None, Document]:
        """Answer the question based on the evidence

    def extract_evidence_images(self, evidence: str):
        """Util function to extract and isolate images from context/evidence"""
        image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
        matches = re.findall(image_pattern, evidence)
        context = re.sub(image_pattern, "", evidence)
        return context, matches
        In addition to the question and the evidence, this method also take into
        account evidence_mode. The evidence_mode tells which kind of evidence is.
        The kind of evidence affects:
            1. How the evidence is represented.
            2. The prompt to generate the answer.

        By default, the evidence_mode is 0, which means the evidence is plain text with
        no particular semantic representation. The evidence_mode can be:
            1. "table": There will be HTML markup telling that there is a table
                within the evidence.
            2. "chatbot": There will be HTML markup telling that there is a chatbot.
                This chatbot is a scenario, extracted from an Excel file, where each
                row corresponds to an interaction.

        Args:
            question: the original question posed by user
            evidence: the text that contain relevant information to answer the question
                (determined by retrieval pipeline)
            evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
        """
        if evidence_mode == EVIDENCE_MODE_TEXT:
            prompt_template = PromptTemplate(self.qa_template)
        elif evidence_mode == EVIDENCE_MODE_TABLE:
            prompt_template = PromptTemplate(self.qa_table_template)
        elif evidence_mode == EVIDENCE_MODE_FIGURE:
            prompt_template = PromptTemplate(self.qa_figure_template)
        else:
            prompt_template = PromptTemplate(self.qa_chatbot_template)

        images = []
        if evidence_mode == EVIDENCE_MODE_FIGURE:
            # isolate image from evidence
            evidence, images = self.extract_evidence_images(evidence)
            prompt = prompt_template.populate(
                context=evidence,
                question=question,
                lang=self.lang,
            )
        else:
            prompt = prompt_template.populate(
                context=evidence,
                question=question,
                lang=self.lang,
            )

        output = ""
        if evidence_mode == EVIDENCE_MODE_FIGURE:
            for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
                output += text
                yield Document(channel="chat", content=text)
        else:
            messages = []
            if self.system_prompt:
                messages.append(SystemMessage(content=self.system_prompt))
            messages.append(HumanMessage(content=prompt))

            try:
                # try streaming first
                print("Trying LLM streaming")
                for text in self.llm.stream(messages):
                    output += text.text
                    yield Document(channel="chat", content=text.text)
            except NotImplementedError:
                print("Streaming is not supported, falling back to normal processing")
                output = self.llm(messages).text
                yield Document(channel="chat", content=output)

        # retrieve the citation
        citation = None
        if evidence and self.enable_citation:
            citation = self.citation_pipeline.invoke(
                context=evidence, question=question
            )

        answer = Document(text=output, metadata={"citation": citation})

        return answer

    def extract_evidence_images(self, evidence: str):
        """Util function to extract and isolate images from context/evidence"""
        image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
        matches = re.findall(image_pattern, evidence)
        context = re.sub(image_pattern, "", evidence)
        return context, matches


class RewriteQuestionPipeline(BaseComponent):
@@ -315,27 +397,19 @@ class RewriteQuestionPipeline(BaseComponent):
        lang: the language of the answer. Currently support English and Japanese
    """

    llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
    llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
    rewrite_template: str = DEFAULT_REWRITE_PROMPT

    lang: str = "English"

    async def run(self, question: str) -> Document:  # type: ignore
    def run(self, question: str) -> Document:  # type: ignore
        prompt_template = PromptTemplate(self.rewrite_template)
        prompt = prompt_template.populate(question=question, lang=self.lang)
        messages = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content=prompt),
        ]
        output = ""
        for text in self.llm(messages):
            if "content" in text:
                output += text[1]
                self.report_output({"chat_input": text[1]})
                break
            await asyncio.sleep(0)

        return Document(text=output)
        return self.llm(messages)


class FullQAPipeline(BaseReasoning):
@@ -351,7 +425,7 @@ class FullQAPipeline(BaseReasoning):
    rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
    use_rewrite: bool = False

    async def run(  # type: ignore
    async def ainvoke(  # type: ignore
        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Document:  # type: ignore
        import markdown
@@ -482,6 +556,132 @@ class FullQAPipeline(BaseReasoning):
        self.report_output(None)
        return answer

    def stream(  # type: ignore
        self, message: str, conv_id: str, history: list, **kwargs  # type: ignore
    ) -> Generator[Document, None, Document]:
        import markdown

        docs = []
        doc_ids = []
        if self.use_rewrite:
            message = self.rewrite_pipeline(question=message).text

        for retriever in self.retrievers:
            for doc in retriever(text=message):
                if doc.doc_id not in doc_ids:
                    docs.append(doc)
                    doc_ids.append(doc.doc_id)
        for doc in docs:
            # TODO: a better approach to show the information
            text = markdown.markdown(
                doc.text, extensions=["markdown.extensions.tables"]
            )
            yield Document(
                content=(
                    "<details open>"
                    f"<summary>{doc.metadata['file_name']}</summary>"
                    f"{text}"
                    "</details><br>"
                ),
                channel="info",
            )

        evidence_mode, evidence = self.evidence_pipeline(docs).content
        answer = yield from self.answering_pipeline.stream(
            question=message,
            history=history,
            evidence=evidence,
            evidence_mode=evidence_mode,
            conv_id=conv_id,
            **kwargs,
        )

        # prepare citation
        spans = defaultdict(list)
        if answer.metadata["citation"] is not None:
            for fact_with_evidence in answer.metadata["citation"].answer:
                for quote in fact_with_evidence.substring_quote:
                    for doc in docs:
                        start_idx = doc.text.find(quote)
                        if start_idx == -1:
                            continue

                        end_idx = start_idx + len(quote)

                        current_idx = start_idx
                        if "|" not in doc.text[start_idx:end_idx]:
                            spans[doc.doc_id].append(
                                {"start": start_idx, "end": end_idx}
                            )
                        else:
                            while doc.text[current_idx:end_idx].find("|") != -1:
                                match_idx = doc.text[current_idx:end_idx].find("|")
                                spans[doc.doc_id].append(
                                    {
                                        "start": current_idx,
                                        "end": current_idx + match_idx,
                                    }
                                )
                                current_idx += match_idx + 2
                                if current_idx > end_idx:
                                    break
                        break

        id2docs = {doc.doc_id: doc for doc in docs}
        lack_evidence = True
        not_detected = set(id2docs.keys()) - set(spans.keys())
        yield Document(channel="info", content=None)
        for id, ss in spans.items():
            if not ss:
                not_detected.add(id)
                continue
            ss = sorted(ss, key=lambda x: x["start"])
            text = id2docs[id].text[: ss[0]["start"]]
            for idx, span in enumerate(ss):
                text += (
                    "<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
                )
                if idx < len(ss) - 1:
                    text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
            text += id2docs[id].text[ss[-1]["end"] :]
            text_out = markdown.markdown(
                text, extensions=["markdown.extensions.tables"]
            )
            yield Document(
                content=(
                    "<details open>"
                    f"<summary>{id2docs[id].metadata['file_name']}</summary>"
                    f"{text_out}"
                    "</details><br>"
                ),
                channel="info",
            )
            lack_evidence = False

        if lack_evidence:
            yield Document(channel="info", content="No evidence found.\n")

        if not_detected:
            yield Document(
                channel="info",
                content="Retrieved segments without matching evidence:\n",
            )
            for id in list(not_detected):
                text_out = markdown.markdown(
                    id2docs[id].text, extensions=["markdown.extensions.tables"]
                )
                yield Document(
                    content=(
                        "<details>"
                        f"<summary>{id2docs[id].metadata['file_name']}</summary>"
                        f"{text_out}"
                        "</details><br>"
                    ),
                    channel="info",
                )

        return answer

    @classmethod
    def get_pipeline(cls, settings, states, retrievers):
        """Get the reasoning pipeline
@@ -493,12 +693,9 @@ class FullQAPipeline(BaseReasoning):
        _id = cls.get_info()["id"]

        pipeline = FullQAPipeline(retrievers=retrievers)
        pipeline.answering_pipeline.llm = llms[
            settings[f"reasoning.options.{_id}.main_llm"]
        ]
        pipeline.answering_pipeline.citation_pipeline.llm = llms[
            settings[f"reasoning.options.{_id}.citation_llm"]
        ]
        pipeline.answering_pipeline.llm = llms.get_default()
        pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()

        pipeline.answering_pipeline.enable_citation = settings[
            f"reasoning.options.{_id}.highlight_citation"
        ]
@@ -512,7 +709,7 @@ class FullQAPipeline(BaseReasoning):
            f"reasoning.options.{_id}.qa_prompt"
        ]
        pipeline.use_rewrite = states.get("app", {}).get("regen", False)
        pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
        pipeline.rewrite_pipeline.llm = llms.get_default()
        pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
            settings["reasoning.lang"], "English"
        )
@@ -520,38 +717,12 @@ class FullQAPipeline(BaseReasoning):

    @classmethod
    def get_user_settings(cls) -> dict:
        from ktem.components import llms

        try:
            citation_llm = llms.get_lowest_cost_name()
            citation_llm_choices = list(llms.options().keys())
            main_llm = llms.get_highest_accuracy_name()
            main_llm_choices = list(llms.options().keys())
        except Exception as e:
            logger.error(e)
            citation_llm = None
            citation_llm_choices = []
            main_llm = None
            main_llm_choices = []

        return {
            "highlight_citation": {
                "name": "Highlight Citation",
                "value": False,
                "component": "checkbox",
            },
            "citation_llm": {
                "name": "LLM for citation",
                "value": citation_llm,
                "component": "dropdown",
                "choices": citation_llm_choices,
            },
            "main_llm": {
                "name": "LLM for main generation",
                "value": main_llm,
                "component": "dropdown",
                "choices": main_llm_choices,
            },
            "system_prompt": {
                "name": "System Prompt",
                "value": "This is a question answering system",
@@ -7,7 +7,7 @@ from index import ReaderIndexingPipeline
from openai.resources.embeddings import Embeddings
from openai.types.chat.chat_completion import ChatCompletion

from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI

with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
    openai_embedding = json.load(f)
@@ -61,7 +61,7 @@ def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
    assert len(results) == 1

    # create llm
    llm = AzureChatOpenAI(
    llm = LCAzureChatOpenAI(
        openai_api_base="https://test.openai.azure.com/",
        openai_api_key="some-key",
        openai_api_version="2023-03-15-preview",
@@ -2,4 +2,4 @@ from ktem.main import App

app = App()
demo = app.make()
demo.queue().launch(favicon_path=app._favicon, inbrowser=True)
demo.queue().launch(favicon_path=app._favicon)
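Taken together, the UI no longer awaits an async pipeline run: it iterates the new `stream` method and routes each yielded `Document` by channel, falling back to a placeholder when the answer text is still empty. A condensed sketch of that consumption loop, assuming a `pipeline` object built by `create_pipeline` and the same placeholder setting shown above (the helper name is illustrative):

```python
from theflow.settings import settings as flowsettings
from kotaemon.base import Document


def consume(pipeline, chat_input, conv_id, history):
    """Condensed version of the loop in ChatPage.chat_fn (illustrative helper)."""
    text, refs = "", ""
    placeholder = getattr(flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ...")
    for response in pipeline.stream(chat_input, conv_id, history):
        if not isinstance(response, Document) or response.channel is None:
            continue
        if response.channel == "chat":
            # reset on None, otherwise append the streamed chunk
            text = "" if response.content is None else text + response.content
        elif response.channel == "info":
            refs = "" if response.content is None else refs + response.content
        # an empty answer still shows the placeholder instead of a blank bubble
        yield history + [(chat_input, text or placeholder)], refs
```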