from typing import Optional, Type, overload

from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize, import_dotted_string

from kotaemon.llms import ChatLLM

from .db import LLMTable, engine


class LLMManager:
    """Represent a pool of models"""

    def __init__(self):
        self._models: dict[str, ChatLLM] = {}
        self._info: dict[str, dict] = {}
        self._default: str = ""
        self._vendors: list[Type] = []

        # seed the database with the models declared in flowsettings,
        # skipping any name that is already registered
        if hasattr(flowsettings, "KH_LLMS"):
            for name, model in flowsettings.KH_LLMS.items():
                with Session(engine) as session:
                    stmt = select(LLMTable).where(LLMTable.name == name)
                    result = session.execute(stmt)
                    if not result.first():
                        item = LLMTable(
                            name=name,
                            spec=model["spec"],
                            default=model.get("default", False),
                        )
                        session.add(item)
                        session.commit()

        self.load()
        self.load_vendors()
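
    # For reference, a minimal sketch of the flowsettings.KH_LLMS structure
    # consumed above. The "gpt4" name and the spec fields are hypothetical
    # illustrations, not values shipped with the project; the "__type__"
    # dotted-string convention is assumed from theflow's deserialize.
    #
    # KH_LLMS = {
    #     "gpt4": {
    #         "spec": {
    #             "__type__": "kotaemon.llms.ChatOpenAI",
    #             "model": "gpt-4",
    #             "temperature": 0,
    #         },
    #         "default": True,
    #     },
    # }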

    def load(self):
        """Load the model pool from the database"""
        self._models, self._info, self._default = {}, {}, ""
        with Session(engine) as session:
            stmt = select(LLMTable)
            items = session.execute(stmt)

            for (item,) in items:
                self._models[item.name] = deserialize(item.spec, safe=False)
                self._info[item.name] = {
                    "name": item.name,
                    "spec": item.spec,
                    "default": item.default,
                }
                if item.default:
                    self._default = item.name

    def load_vendors(self):
        """Load the supported vendor classes, plus extras from flowsettings"""
        from kotaemon.llms import (
            AzureChatOpenAI,
            ChatOpenAI,
            LCAnthropicChat,
            LCCohereChat,
            LCGeminiChat,
            LCOllamaChat,
            LlamaCppChat,
        )

        self._vendors = [
            ChatOpenAI,
            AzureChatOpenAI,
            LCAnthropicChat,
            LCGeminiChat,
            LCCohereChat,
            LCOllamaChat,
            LlamaCppChat,
        ]

        for extra_vendor in getattr(flowsettings, "KH_LLM_EXTRA_VENDORS", []):
            self._vendors.append(import_dotted_string(extra_vendor, safe=False))
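
    # For reference, an extra vendor entry is a dotted import path to a
    # ChatLLM subclass; the package and class below are hypothetical:
    #
    # KH_LLM_EXTRA_VENDORS = ["my_company.llms.MyCustomChat"]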

    def __getitem__(self, key: str) -> ChatLLM:
        """Get model by name"""
        return self._models[key]

    def __contains__(self, key: str) -> bool:
        """Check if model exists"""
        return key in self._models

    @overload
    def get(self, key: str, default: None) -> Optional[ChatLLM]:
        ...

    @overload
    def get(self, key: str, default: ChatLLM) -> ChatLLM:
        ...

    def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
        """Get model by name, returning the default value if not found"""
        return self._models.get(key, default)

    def settings(self) -> dict:
        """Present the model pool options for Gradio"""
        return {
            "label": "LLM",
            "choices": list(self._models.keys()),
            "value": self.get_default_name(),
        }

    def options(self) -> dict:
        """Present a dict of models"""
        return self._models

    def get_random_name(self) -> str:
        """Get the name of a random model

        Returns:
            str: random model name in the pool
        """
        import random

        if not self._models:
            raise ValueError("No models in pool")

        return random.choice(list(self._models.keys()))

    def get_default_name(self) -> str:
        """Get the name of the default model

        If there is no default model, choose a random model from the pool.
        If there are multiple default models, choose randomly among them.

        Returns:
            str: model name
        """
        if not self._models:
            raise ValueError("No models in pool")

        if not self._default:
            return self.get_random_name()

        return self._default

    def get_random(self) -> ChatLLM:
        """Get a random model"""
        return self._models[self.get_random_name()]

    def get_default(self) -> ChatLLM:
        """Get the default model

        If there is no default model, choose a random model from the pool.
        If there are multiple default models, choose randomly among them.

        Returns:
            ChatLLM: model
        """
        return self._models[self.get_default_name()]

    def info(self) -> dict:
        """List all models"""
        return self._info

    def add(self, name: str, spec: dict, default: bool):
        """Add a new model to the pool"""
        name = name.strip()
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as session:
                if default:
                    # turn all models non-default so only one default remains
                    session.query(LLMTable).update({"default": False})
                    session.commit()

                item = LLMTable(name=name, spec=spec, default=default)
                session.add(item)
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to add model {name}: {e}")

        self.load()

    def delete(self, name: str):
        """Delete a model from the pool"""
        try:
            with Session(engine) as session:
                item = session.query(LLMTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")
                session.delete(item)
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to delete model {name}: {e}")

        self.load()

    def update(self, name: str, spec: dict, default: bool):
        """Update a model in the pool"""
        if not name:
            raise ValueError("Name must not be empty")

        try:
            with Session(engine) as session:
                if default:
                    # turn all models non-default before promoting this one
                    session.query(LLMTable).update({"default": False})
                    session.commit()

                item = session.query(LLMTable).filter_by(name=name).first()
                if not item:
                    raise ValueError(f"Model {name} not found")
                item.spec = spec
                item.default = default
                session.commit()
        except Exception as e:
            raise ValueError(f"Failed to update model {name}: {e}")

        self.load()

    def vendors(self) -> dict:
        """Return a mapping from vendor qualname to vendor class"""
        return {vendor.__qualname__: vendor for vendor in self._vendors}


llms = LLMManager()
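
if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the module's
    # public API). It assumes the backing database is reachable and that at
    # least one model was registered via flowsettings.KH_LLMS or
    # LLMManager.add(); the model name "my-gpt" is a hypothetical example.
    print(llms.info())  # registered models with their specs

    if "my-gpt" in llms:
        model = llms["my-gpt"]  # raises KeyError if absent
    else:
        model = llms.get("my-gpt", None)  # returns None instead of raising

    if llms.options():
        # falls back to a random model name when no default is flagged
        print(llms.get_default_name())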