kotaemon/libs/ktem/ktem/embeddings/manager.py
2024-09-06 18:18:19 +07:00

208 lines
6.2 KiB
Python

from typing import Optional, Type
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize
from kotaemon.embeddings.base import BaseEmbeddings
from .db import EmbeddingTable, engine
class EmbeddingManager:
"""Represent a pool of models"""
def __init__(self):
self._models: dict[str, BaseEmbeddings] = {}
self._info: dict[str, dict] = {}
self._default: str = ""
self._vendors: list[Type] = []
# populate the pool if empty
if hasattr(flowsettings, "KH_EMBEDDINGS"):
with Session(engine) as sess:
count = sess.query(EmbeddingTable).count()
if not count:
for name, model in flowsettings.KH_EMBEDDINGS.items():
self.add(
name=name,
spec=model["spec"],
default=model.get("default", False),
)
self.load()
self.load_vendors()
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._default = {}, {}, ""
with Session(engine) as sess:
stmt = select(EmbeddingTable)
items = sess.execute(stmt)
for (item,) in items:
self._models[item.name] = deserialize(item.spec, safe=False)
self._info[item.name] = {
"name": item.name,
"spec": item.spec,
"default": item.default,
}
if item.default:
self._default = item.name
def load_vendors(self):
from kotaemon.embeddings import (
AzureOpenAIEmbeddings,
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
OpenAIEmbeddings,
)
self._vendors = [
AzureOpenAIEmbeddings,
OpenAIEmbeddings,
FastEmbedEmbeddings,
LCCohereEmbeddings,
LCHuggingFaceEmbeddings,
]
def __getitem__(self, key: str) -> BaseEmbeddings:
"""Get model by name"""
return self._models[key]
def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models
def get(
self, key: str, default: Optional[BaseEmbeddings] = None
) -> Optional[BaseEmbeddings]:
"""Get model by name with default value"""
return self._models.get(key, default)
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
"label": "Embedding",
"choices": list(self._models.keys()),
"value": self.get_default_name(),
}
def options(self) -> dict:
"""Present a dict of models"""
return self._models
def get_random_name(self) -> str:
"""Get the name of random model
Returns:
str: random model name in the pool
"""
import random
if not self._models:
raise ValueError("No models in pool")
return random.choice(list(self._models.keys()))
def get_default_name(self) -> str:
"""Get the name of default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
str: model name
"""
if not self._models:
raise ValueError("No models in pool")
if not self._default:
return self.get_random_name()
return self._default
def get_random(self) -> BaseEmbeddings:
"""Get random model"""
return self._models[self.get_random_name()]
def get_default(self) -> BaseEmbeddings:
"""Get default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
BaseEmbeddings: model
"""
return self._models[self.get_default_name()]
def info(self) -> dict:
"""List all models"""
return self._info
def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
if not name:
raise ValueError("Name must not be empty")
try:
with Session(engine) as sess:
if default:
# turn all models to non-default
sess.query(EmbeddingTable).update({"default": False})
sess.commit()
item = EmbeddingTable(name=name, spec=spec, default=default)
sess.add(item)
sess.commit()
except Exception as e:
raise ValueError(f"Failed to add model {name}: {e}")
self.load()
def delete(self, name: str):
"""Delete a model from the pool"""
try:
with Session(engine) as sess:
item = sess.query(EmbeddingTable).filter_by(name=name).first()
sess.delete(item)
sess.commit()
except Exception as e:
raise ValueError(f"Failed to delete model {name}: {e}")
self.load()
def update(self, name: str, spec: dict, default: bool):
"""Update a model in the pool"""
if not name:
raise ValueError("Name must not be empty")
try:
with Session(engine) as sess:
if default:
# turn all models to non-default
sess.query(EmbeddingTable).update({"default": False})
sess.commit()
item = sess.query(EmbeddingTable).filter_by(name=name).first()
if not item:
raise ValueError(f"Model {name} not found")
item.spec = spec
item.default = default
sess.commit()
except Exception as e:
raise ValueError(f"Failed to update model {name}: {e}")
self.load()
def vendors(self) -> dict:
"""Return list of vendors"""
return {vendor.__qualname__: vendor for vendor in self._vendors}
embedding_models_manager = EmbeddingManager()