diff --git a/flowsettings.py b/flowsettings.py
index bb2cb7a..e9f9217 100644
--- a/flowsettings.py
+++ b/flowsettings.py
@@ -92,6 +92,7 @@ KH_VECTORSTORE = {
 }
 KH_LLMS = {}
 KH_EMBEDDINGS = {}
+KH_RERANKINGS = {}
 
 # populate options from config
 if config("AZURE_OPENAI_API_KEY", default="") and config(
@@ -212,7 +213,7 @@ KH_LLMS["cohere"] = {
     "spec": {
         "__type__": "kotaemon.llms.chats.LCCohereChat",
         "model_name": "command-r-plus-08-2024",
-        "api_key": "your-key",
+        "api_key": config("COHERE_API_KEY", default="your-key"),
     },
     "default": False,
 }
@@ -222,7 +223,7 @@ KH_EMBEDDINGS["cohere"] = {
     "spec": {
         "__type__": "kotaemon.embeddings.LCCohereEmbeddings",
         "model": "embed-multilingual-v3.0",
-        "cohere_api_key": "your-key",
+        "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
         "user_agent": "default",
     },
     "default": False,
@@ -235,6 +236,16 @@ KH_EMBEDDINGS["cohere"] = {
 #     "default": False,
 # }
 
+# default reranking models
+KH_RERANKINGS["cohere"] = {
+    "spec": {
+        "__type__": "kotaemon.rerankings.CohereReranking",
+        "model_name": "rerank-multilingual-v2.0",
+        "cohere_api_key": config("COHERE_API_KEY", default="your-key"),
+    },
+    "default": True,
+}
+
 KH_REASONINGS = [
     "ktem.reasoning.simple.FullQAPipeline",
     "ktem.reasoning.simple.FullDecomposeQAPipeline",
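Additional rerankers can be registered in `KH_RERANKINGS` the same way. A minimal sketch (the endpoint URL is a placeholder for a locally deployed TEI service; `TeiFastReranking` is the class added later in this diff):

```python
# Sketch: registering an extra, non-default reranker in flowsettings.py.
KH_RERANKINGS["tei"] = {
    "spec": {
        "__type__": "kotaemon.rerankings.TeiFastReranking",
        "endpoint_url": "http://localhost:8080/rerank",  # placeholder URL
    },
    "default": False,
}
```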
+ """ + + endpoint_url: str = Param(None, help="TEI embedding service api base URL") + normalize: bool = Param( + True, + help="Normalize embeddings to unit length", + ) + truncate: bool = Param( + True, + help="Truncate embeddings to a fixed/default length", + ) + + async def client_(self, inputs: list[str]): + async with aiohttp.ClientSession() as session: + async with session.post( + url=self.endpoint_url, + json={ + "inputs": inputs, + "normalize": self.normalize, + "truncate": self.truncate, + }, + ) as resp: + embeddings = await resp.json() + return embeddings + + async def ainvoke( + self, text: str | list[str] | Document | list[Document], *args, **kwargs + ) -> list[DocumentWithEmbedding]: + if not isinstance(text, list): + text = [text] + text = self.prepare_input(text) + + outputs = [] + batch_size = 6 + num_batch = max(len(text) // batch_size, 1) + for i in range(num_batch): + if i == num_batch - 1: + mini_batch = text[batch_size * i :] + else: + mini_batch = text[batch_size * i : batch_size * (i + 1)] + mini_batch = [x.content for x in mini_batch] + embeddings = await self.client_(mini_batch) # type: ignore + outputs.extend( + [ + DocumentWithEmbedding(content=doc, embedding=embedding) + for doc, embedding in zip(mini_batch, embeddings) + ] + ) + + return outputs + + def invoke( + self, text: str | list[str] | Document | list[Document], *args, **kwargs + ) -> list[DocumentWithEmbedding]: + if not isinstance(text, list): + text = [text] + + text = self.prepare_input(text) + + outputs = [] + batch_size = 6 + num_batch = max(len(text) // batch_size, 1) + for i in range(num_batch): + if i == num_batch - 1: + mini_batch = text[batch_size * i :] + else: + mini_batch = text[batch_size * i : batch_size * (i + 1)] + mini_batch = [x.content for x in mini_batch] + embeddings = session.post( + url=self.endpoint_url, + json={ + "inputs": mini_batch, + "normalize": self.normalize, + "truncate": self.truncate, + }, + ).json() + outputs.extend( + [ + DocumentWithEmbedding(content=doc, embedding=embedding) + for doc, embedding in zip(mini_batch, embeddings) + ] + ) + return outputs diff --git a/libs/kotaemon/kotaemon/indices/rankings/cohere.py b/libs/kotaemon/kotaemon/indices/rankings/cohere.py index b4ce97e..9515d12 100644 --- a/libs/kotaemon/kotaemon/indices/rankings/cohere.py +++ b/libs/kotaemon/kotaemon/indices/rankings/cohere.py @@ -39,7 +39,7 @@ class CohereReranking(BaseReranking): print("Cannot get Cohere API key from `ktem`", e) if not self.cohere_api_key: - print("Cohere API key not found. Skipping reranking.") + print("Cohere API key not found. 
diff --git a/libs/kotaemon/kotaemon/indices/rankings/cohere.py b/libs/kotaemon/kotaemon/indices/rankings/cohere.py
index b4ce97e..9515d12 100644
--- a/libs/kotaemon/kotaemon/indices/rankings/cohere.py
+++ b/libs/kotaemon/kotaemon/indices/rankings/cohere.py
@@ -39,7 +39,7 @@ class CohereReranking(BaseReranking):
             print("Cannot get Cohere API key from `ktem`", e)
 
         if not self.cohere_api_key:
-            print("Cohere API key not found. Skipping reranking.")
+            print("Cohere API key not found. Skipping rerankings.")
             return documents
 
         cohere_client = cohere.Client(self.cohere_api_key)
@@ -52,10 +52,9 @@ class CohereReranking(BaseReranking):
         response = cohere_client.rerank(
             model=self.model_name, query=query, documents=_docs
         )
-        # print("Cohere score", [r.relevance_score for r in response.results])
         for r in response.results:
             doc = documents[r.index]
-            doc.metadata["cohere_reranking_score"] = r.relevance_score
+            doc.metadata["reranking_score"] = r.relevance_score
             compressed_docs.append(doc)
 
         return compressed_docs
diff --git a/libs/kotaemon/kotaemon/indices/vectorindex.py b/libs/kotaemon/kotaemon/indices/vectorindex.py
index e2984c7..1906091 100644
--- a/libs/kotaemon/kotaemon/indices/vectorindex.py
+++ b/libs/kotaemon/kotaemon/indices/vectorindex.py
@@ -241,7 +241,7 @@ class VectorRetrieval(BaseRetrieval):
             # if reranker is LLMReranking, limit the document with top_k items only
             if isinstance(reranker, LLMReranking):
                 result = self._filter_docs(result, top_k=top_k)
-            result = reranker(documents=result, query=text)
+            result = reranker.run(documents=result, query=text)
 
         result = self._filter_docs(result, top_k=top_k)
         print(f"Got raw {len(result)} retrieved documents")
diff --git a/libs/kotaemon/kotaemon/rerankings/__init__.py b/libs/kotaemon/kotaemon/rerankings/__init__.py
new file mode 100644
index 0000000..621b9a2
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseReranking
+from .cohere import CohereReranking
+from .tei_fast_rerank import TeiFastReranking
+
+__all__ = ["BaseReranking", "TeiFastReranking", "CohereReranking"]
diff --git a/libs/kotaemon/kotaemon/rerankings/base.py b/libs/kotaemon/kotaemon/rerankings/base.py
new file mode 100644
index 0000000..c9c0b9b
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/base.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+
+from kotaemon.base import BaseComponent, Document
+
+
+class BaseReranking(BaseComponent):
+    @abstractmethod
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Main method to transform list of documents
+        (re-ranking, filtering, etc)"""
+        ...
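Any new vendor only needs to subclass `BaseReranking` and implement `run`. A toy sketch (the class and its term-overlap scoring are invented for illustration; it follows the `reranking_score` metadata convention used throughout this diff):

```python
from kotaemon.base import Document
from kotaemon.rerankings import BaseReranking


class OverlapReranking(BaseReranking):
    """Hypothetical reranker: orders documents by raw query-term overlap."""

    def run(self, documents: list[Document], query: str) -> list[Document]:
        terms = set(query.lower().split())
        for doc in documents:
            # count how many query terms appear in the document
            score = len(terms & set(doc.content.lower().split()))
            doc.metadata["reranking_score"] = float(score)
        return sorted(
            documents, key=lambda d: d.metadata["reranking_score"], reverse=True
        )
```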
diff --git a/libs/kotaemon/kotaemon/rerankings/cohere.py b/libs/kotaemon/kotaemon/rerankings/cohere.py
new file mode 100644
index 0000000..dbc5e9a
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/cohere.py
@@ -0,0 +1,56 @@
+from __future__ import annotations
+
+from decouple import config
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+
+class CohereReranking(BaseReranking):
+    """Cohere Reranking model"""
+
+    model_name: str = Param(
+        "rerank-multilingual-v2.0",
+        help=(
+            "ID of the model to use. You can go to [Supported Models]"
+            "(https://docs.cohere.com/docs/rerank-2) to see the supported models"
+        ),
+        required=True,
+    )
+    cohere_api_key: str = Param(
+        config("COHERE_API_KEY", ""),
+        help="Cohere API key",
+        required=True,
+    )
+
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Use Cohere Reranker model to re-order documents
+        with their relevance score"""
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "Please install Cohere: `pip install cohere` to use Cohere Reranking"
+            )
+
+        if not self.cohere_api_key:
+            print("Cohere API key not found. Skipping rerankings.")
+            return documents
+
+        cohere_client = cohere.Client(self.cohere_api_key)
+        compressed_docs: list[Document] = []
+
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        _docs = [d.content for d in documents]
+        response = cohere_client.rerank(
+            model=self.model_name, query=query, documents=_docs
+        )
+        for r in response.results:
+            doc = documents[r.index]
+            doc.metadata["reranking_score"] = r.relevance_score
+            compressed_docs.append(doc)
+
+        return compressed_docs
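A usage sketch, assuming `COHERE_API_KEY` is set in the environment or `.env` (the document content and query are invented):

```python
from kotaemon.base import Document
from kotaemon.rerankings import CohereReranking

reranker = CohereReranking()  # api key picked up via config("COHERE_API_KEY", "")
ranked = reranker.run(
    documents=[Document(content="Paris is the capital of France.")],
    query="capital of France",
)
print(ranked[0].metadata["reranking_score"])
```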
diff --git a/libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py b/libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py
new file mode 100644
index 0000000..4ac4b8e
--- /dev/null
+++ b/libs/kotaemon/kotaemon/rerankings/tei_fast_rerank.py
@@ -0,0 +1,77 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import requests
+
+from kotaemon.base import Document, Param
+
+from .base import BaseReranking
+
+session = requests.session()
+
+
+class TeiFastReranking(BaseReranking):
+    """Text Embeddings Inference (TEI) Reranking model
+    (https://huggingface.co/docs/text-embeddings-inference/en/index)
+    """
+
+    endpoint_url: str = Param(
+        None, help="TEI Reranking service api base URL", required=True
+    )
+    model_name: Optional[str] = Param(
+        None,
+        help=(
+            "ID of the model to use. You can go to [Supported Models]"
+            "(https://github.com/huggingface"
+            "/text-embeddings-inference?tab=readme-ov-file"
+            "#supported-models) to see the supported models"
+        ),
+    )
+    is_truncated: Optional[bool] = Param(True, help="Whether to truncate the inputs")
+
+    def client(self, query, texts):
+        response = session.post(
+            url=self.endpoint_url,
+            json={
+                "query": query,
+                "texts": texts,
+                "is_truncated": self.is_truncated,  # default is True
+            },
+        ).json()
+        return response
+
+    def run(self, documents: list[Document], query: str) -> list[Document]:
+        """Use the deployed TEI rerankings service to re-order documents
+        with their relevance score"""
+        if not self.endpoint_url:
+            print("TEI API reranking URL not found. Skipping rerankings.")
+            return documents
+
+        compressed_docs: list[Document] = []
+
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        if isinstance(documents[0], str):
+            documents = self.prepare_input(documents)
+
+        batch_size = 6
+        num_batch = max(len(documents) // batch_size, 1)
+        for i in range(num_batch):
+            if i == num_batch - 1:
+                mini_batch = documents[batch_size * i :]
+            else:
+                mini_batch = documents[batch_size * i : batch_size * (i + 1)]
+
+            _docs = [d.content for d in mini_batch]
+            rerank_resp = self.client(query, _docs)
+            for r in rerank_resp:
+                doc = mini_batch[r["index"]]
+                doc.metadata["reranking_score"] = r["score"]
+                compressed_docs.append(doc)
+
+        compressed_docs = sorted(
+            compressed_docs, key=lambda x: x.metadata["reranking_score"], reverse=True
+        )
+        return compressed_docs
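A matching usage sketch (placeholder endpoint; the client above expects each response element to carry an `index` into the submitted texts and a `score`):

```python
from kotaemon.base import Document
from kotaemon.rerankings import TeiFastReranking

reranker = TeiFastReranking(endpoint_url="http://localhost:8080/rerank")
ranked = reranker.run(
    documents=[Document(content="chunk one"), Document(content="chunk two")],
    query="which chunk is relevant?",
)
for doc in ranked:  # already sorted by descending reranking_score
    print(doc.metadata["reranking_score"], doc.content)
```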
diff --git a/libs/ktem/ktem/embeddings/manager.py b/libs/ktem/ktem/embeddings/manager.py
index 88cdacb..c33d151 100644
--- a/libs/ktem/ktem/embeddings/manager.py
+++ b/libs/ktem/ktem/embeddings/manager.py
@@ -59,6 +59,7 @@ class EmbeddingManager:
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
+            TeiEndpointEmbeddings,
         )
 
         self._vendors = [
@@ -67,6 +68,7 @@ class EmbeddingManager:
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            TeiEndpointEmbeddings,
         ]
 
     def __getitem__(self, key: str) -> BaseEmbeddings:
diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py
index 598064a..cd94852 100644
--- a/libs/ktem/ktem/index/file/pipelines.py
+++ b/libs/ktem/ktem/index/file/pipelines.py
@@ -16,6 +16,7 @@ import tiktoken
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager
 from llama_index.core.readers.base import BaseReader
 from llama_index.core.readers.file.base import default_file_metadata_func
 from llama_index.core.vector_stores import (
@@ -39,12 +40,7 @@ from kotaemon.indices.ingests.files import (
     azure_reader,
     unstructured,
 )
-from kotaemon.indices.rankings import (
-    BaseReranking,
-    CohereReranking,
-    LLMReranking,
-    LLMTrulensScoring,
-)
+from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 
 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -285,7 +281,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             ],
             retrieval_mode=user_settings["retrieval_mode"],
             llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
-            rerankers=[CohereReranking(use_key_from_ktem=True)],
+            rerankers=[
+                reranking_models_manager[
+                    index_settings.get(
+                        "reranking", reranking_models_manager.get_default_name()
+                    )
+                ]
+            ],
         )
         if not user_settings["use_reranking"]:
             retriever.rerankers = []  # type: ignore
@@ -715,7 +717,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
         for idx, file_path in enumerate(file_paths):
             file_path = Path(file_path)
             yield Document(
-                content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
+                content=f"Indexing [{idx + 1}/{n_files}]: {file_path.name}",
                 channel="debug",
             )
diff --git a/libs/ktem/ktem/pages/resources/__init__.py b/libs/ktem/ktem/pages/resources/__init__.py
index aa606c9..35bf54c 100644
--- a/libs/ktem/ktem/pages/resources/__init__.py
+++ b/libs/ktem/ktem/pages/resources/__init__.py
@@ -4,6 +4,7 @@ from ktem.db.models import User, engine
 from ktem.embeddings.ui import EmbeddingManagement
 from ktem.index.ui import IndexManagement
 from ktem.llms.ui import LLMManagement
+from ktem.rerankings.ui import RerankingManagement
 from sqlmodel import Session, select
 
 from .user import UserManagement
@@ -24,6 +25,9 @@ class ResourcesTab(BasePage):
             with gr.Tab("Embeddings") as self.emb_management_tab:
                 self.emb_management = EmbeddingManagement(self._app)
 
+            with gr.Tab("Rerankings") as self.rerank_management_tab:
+                self.rerank_management = RerankingManagement(self._app)
+
             if self._app.f_user_management:
                 with gr.Tab("Users", visible=False) as self.user_management_tab:
                     self.user_management = UserManagement(self._app)
diff --git a/libs/ktem/ktem/pages/setup.py b/libs/ktem/ktem/pages/setup.py
index 5199ec4..f7e70a1 100644
--- a/libs/ktem/ktem/pages/setup.py
+++ b/libs/ktem/ktem/pages/setup.py
@@ -5,6 +5,7 @@ import requests
 from ktem.app import BasePage
 from ktem.embeddings.manager import embedding_models_manager as embeddings
 from ktem.llms.manager import llms
+from ktem.rerankings.manager import reranking_models_manager as rerankers
 from theflow.settings import settings as flowsettings
 
 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
@@ -186,6 +187,15 @@ class SetupPage(BasePage):
                 },
                 default=True,
             )
+            rerankers.update(
+                name="cohere",
+                spec={
+                    "__type__": "kotaemon.rerankings.CohereReranking",
+                    "model_name": "rerank-multilingual-v2.0",
+                    "cohere_api_key": cohere_api_key,
+                },
+                default=True,
+            )
         elif radio_model_value == "openai":
             if openai_api_key:
                 llms.update(
diff --git a/libs/ktem/ktem/reasoning/react.py b/libs/ktem/ktem/reasoning/react.py
index afdd931..d73a568 100644
--- a/libs/ktem/ktem/reasoning/react.py
+++ b/libs/ktem/ktem/reasoning/react.py
@@ -100,7 +100,7 @@ class DocSearchTool(BaseTool):
             )
             print("Retrieved #{}: {}".format(_id, retrieved_content[:100]))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))
 
         # trim context by trim_len
         if evidence:
diff --git a/libs/ktem/ktem/reasoning/rewoo.py b/libs/ktem/ktem/reasoning/rewoo.py
index e4d461f..f751342 100644
--- a/libs/ktem/ktem/reasoning/rewoo.py
+++ b/libs/ktem/ktem/reasoning/rewoo.py
@@ -138,7 +138,7 @@ class DocSearchTool(BaseTool):
             )
             print("Retrieved #{}: {}".format(_id, retrieved_content))
-            print("Score", retrieved_item.metadata.get("cohere_reranking_score", None))
+            print("Score", retrieved_item.metadata.get("reranking_score", None))
 
         # trim context by trim_len
         if evidence:
diff --git a/libs/ktem/ktem/rerankings/__init__.py b/libs/ktem/ktem/rerankings/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/libs/ktem/ktem/rerankings/db.py b/libs/ktem/ktem/rerankings/db.py
new file mode 100644
index 0000000..85049df
--- /dev/null
+++ b/libs/ktem/ktem/rerankings/db.py
@@ -0,0 +1,36 @@
+from typing import Type
+
+from ktem.db.engine import engine
+from sqlalchemy import JSON, Boolean, Column, String
+from sqlalchemy.orm import DeclarativeBase
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import import_dotted_string
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class BaseRerankingTable(Base):
+    """Base table to store rerankings model"""
+
+    __abstract__ = True
+
+    name = Column(String, primary_key=True, unique=True)
+    spec = Column(JSON, default={})
+    default = Column(Boolean, default=False)
+
+
+__base_reranking: Type[BaseRerankingTable] = (
+    import_dotted_string(flowsettings.KH_TABLE_RERANKING, safe=False)
+    if hasattr(flowsettings, "KH_TABLE_RERANKING")
+    else BaseRerankingTable
+)
+
+
+class RerankingTable(__base_reranking):  # type: ignore
+    __tablename__ = "reranking"
+
+
+if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
+    RerankingTable.metadata.create_all(engine)
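The `KH_TABLE_RERANKING` hook lets a deployment swap in its own table class. A hypothetical override in `flowsettings.py` (the dotted path is invented for illustration):

```python
# Point the reranking table at a custom BaseRerankingTable subclass.
KH_TABLE_RERANKING = "my_app.tables.CustomRerankingTable"
```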
diff --git a/libs/ktem/ktem/rerankings/manager.py b/libs/ktem/ktem/rerankings/manager.py
new file mode 100644
index 0000000..a9facc1
--- /dev/null
+++ b/libs/ktem/ktem/rerankings/manager.py
@@ -0,0 +1,194 @@
+from typing import Optional, Type
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+from theflow.settings import settings as flowsettings
+from theflow.utils.modules import deserialize
+
+from kotaemon.rerankings.base import BaseReranking
+
+from .db import RerankingTable, engine
+
+
+class RerankingManager:
+    """Represent a pool of rerankings models"""
+
+    def __init__(self):
+        self._models: dict[str, BaseReranking] = {}
+        self._info: dict[str, dict] = {}
+        self._default: str = ""
+        self._vendors: list[Type] = []
+
+        # populate the pool if empty
+        if hasattr(flowsettings, "KH_RERANKINGS"):
+            with Session(engine) as sess:
+                count = sess.query(RerankingTable).count()
+            if not count:
+                for name, model in flowsettings.KH_RERANKINGS.items():
+                    self.add(
+                        name=name,
+                        spec=model["spec"],
+                        default=model.get("default", False),
+                    )
+
+        self.load()
+        self.load_vendors()
+
+    def load(self):
+        """Load the model pool from database"""
+        self._models, self._info, self._default = {}, {}, ""
+        with Session(engine) as sess:
+            stmt = select(RerankingTable)
+            items = sess.execute(stmt)
+
+            for (item,) in items:
+                self._models[item.name] = deserialize(item.spec, safe=False)
+                self._info[item.name] = {
+                    "name": item.name,
+                    "spec": item.spec,
+                    "default": item.default,
+                }
+                if item.default:
+                    self._default = item.name
+
+    def load_vendors(self):
+        from kotaemon.rerankings import CohereReranking, TeiFastReranking
+
+        self._vendors = [TeiFastReranking, CohereReranking]
+
+    def __getitem__(self, key: str) -> BaseReranking:
+        """Get model by name"""
+        return self._models[key]
+
+    def __contains__(self, key: str) -> bool:
+        """Check if model exists"""
+        return key in self._models
+
+    def get(
+        self, key: str, default: Optional[BaseReranking] = None
+    ) -> Optional[BaseReranking]:
+        """Get model by name with default value"""
+        return self._models.get(key, default)
+
+    def settings(self) -> dict:
+        """Present model pools option for gradio"""
+        return {
+            "label": "Reranking",
+            "choices": list(self._models.keys()),
+            "value": self.get_default_name(),
+        }
+
+    def options(self) -> dict:
+        """Present a dict of models"""
+        return self._models
+
+    def get_random_name(self) -> str:
+        """Get the name of a random model
+
+        Returns:
+            str: random model name in the pool
+        """
+        import random
+
+        if not self._models:
+            raise ValueError("No models in pool")
+
+        return random.choice(list(self._models.keys()))
+
+    def get_default_name(self) -> str:
+        """Get the name of the default model
+
+        If there is no default model, choose a random one from the pool. If
+        there are multiple default models, choose randomly among them.
+
+        Returns:
+            str: model name
+        """
+        if not self._models:
+            raise ValueError("No models in pool")
+
+        if not self._default:
+            return self.get_random_name()
+
+        return self._default
+
+    def get_random(self) -> BaseReranking:
+        """Get random model"""
+        return self._models[self.get_random_name()]
+
+    def get_default(self) -> BaseReranking:
+        """Get the default model
+
+        If there is no default model, choose a random one from the pool. If
+        there are multiple default models, choose randomly among them.
+
+        Returns:
+            BaseReranking: model
+        """
+        return self._models[self.get_default_name()]
+
+    def info(self) -> dict:
+        """List all models"""
+        return self._info
+
+    def add(self, name: str, spec: dict, default: bool):
+        if not name:
+            raise ValueError("Name must not be empty")
+
+        try:
+            with Session(engine) as sess:
+                if default:
+                    # turn all models to non-default
+                    sess.query(RerankingTable).update({"default": False})
+                    sess.commit()
+
+                item = RerankingTable(name=name, spec=spec, default=default)
+                sess.add(item)
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to add model {name}: {e}")
+
+        self.load()
+
+    def delete(self, name: str):
+        """Delete a model from the pool"""
+        try:
+            with Session(engine) as sess:
+                item = sess.query(RerankingTable).filter_by(name=name).first()
+                sess.delete(item)
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to delete model {name}: {e}")
+
+        self.load()
+
+    def update(self, name: str, spec: dict, default: bool):
+        """Update a model in the pool"""
+        if not name:
+            raise ValueError("Name must not be empty")
+
+        try:
+            with Session(engine) as sess:
+                if default:
+                    # turn all models to non-default
+                    sess.query(RerankingTable).update({"default": False})
+                    sess.commit()
+
+                item = sess.query(RerankingTable).filter_by(name=name).first()
+                if not item:
+                    raise ValueError(f"Model {name} not found")
+                item.spec = spec
+                item.default = default
+                sess.commit()
+        except Exception as e:
+            raise ValueError(f"Failed to update model {name}: {e}")
+
+        self.load()
+
+    def vendors(self) -> dict:
+        """Return a dict of available vendors"""
+        return {vendor.__qualname__: vendor for vendor in self._vendors}
+
+
+reranking_models_manager = RerankingManager()
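A sketch of how other ktem components consume the manager (this mirrors the `DocumentRetrievalPipeline` change above):

```python
from ktem.rerankings.manager import reranking_models_manager

# grab the default reranker directly...
default_reranker = reranking_models_manager.get_default()

# ...or resolve one by name, falling back to the default
name = reranking_models_manager.get_default_name()
reranker = reranking_models_manager[name]
```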
+ ), + ) + self.edit_spec = gr.Textbox( + label="Specification", + info="Specification of the Embedding model in YAML format", + lines=10, + ) + + with gr.Accordion( + label="Test connection", visible=False, open=False + ) as self._check_connection_panel: + with gr.Row(): + with gr.Column(scale=4): + self.connection_logs = gr.HTML( + "Logs", + ) + + with gr.Column(scale=1): + self.btn_test_connection = gr.Button("Test") + + with gr.Row(visible=False) as self._selected_panel_btn: + with gr.Column(): + self.btn_edit_save = gr.Button( + "Save", min_width=10, variant="primary" + ) + with gr.Column(): + self.btn_delete = gr.Button( + "Delete", min_width=10, variant="stop" + ) + with gr.Row(): + self.btn_delete_yes = gr.Button( + "Confirm Delete", + variant="stop", + visible=False, + min_width=10, + ) + self.btn_delete_no = gr.Button( + "Cancel", visible=False, min_width=10 + ) + with gr.Column(): + self.btn_close = gr.Button("Close", min_width=10) + + with gr.Column(): + self.edit_spec_desc = gr.Markdown("# Spec description") + + with gr.Tab(label="Add"): + with gr.Row(): + with gr.Column(scale=2): + self.name = gr.Textbox( + label="Name", + info=( + "Must be unique and non-empty. " + "The name will be used to identify the reranking model." + ), + ) + self.rerank_choices = gr.Dropdown( + label="Vendors", + info=( + "Choose the vendor of the Reranking model. Each vendor " + "has different specification." + ), + ) + self.spec = gr.Textbox( + label="Specification", + info="Specification of the Embedding model in YAML format.", + ) + self.default = gr.Checkbox( + label="Set default", + info=( + "Set this Reranking model as default. This default " + "Reranking will be used by other components by default " + "if no Reranking is specified for such components." + ), + ) + self.btn_new = gr.Button("Add", variant="primary") + + with gr.Column(scale=3): + self.spec_desc = gr.Markdown(self.spec_desc_default) + + def _on_app_created(self): + """Called when the app is created""" + self._app.app.load( + self.list_rerankings, + inputs=[], + outputs=[self.rerank_list], + ) + self._app.app.load( + lambda: gr.update(choices=list(reranking_models_manager.vendors().keys())), + outputs=[self.rerank_choices], + ) + + def on_rerank_vendor_change(self, vendor): + vendor = reranking_models_manager.vendors()[vendor] + + required: dict = {} + desc = vendor.describe() + for key, value in desc["params"].items(): + if value.get("required", False): + required[key] = value.get("default", None) + + return yaml.dump(required), format_description(vendor) + + def on_register_events(self): + self.rerank_choices.select( + self.on_rerank_vendor_change, + inputs=[self.rerank_choices], + outputs=[self.spec, self.spec_desc], + ) + self.btn_new.click( + self.create_rerank, + inputs=[self.name, self.rerank_choices, self.spec, self.default], + outputs=None, + ).success(self.list_rerankings, inputs=[], outputs=[self.rerank_list]).success( + lambda: ("", None, "", False, self.spec_desc_default), + outputs=[ + self.name, + self.rerank_choices, + self.spec, + self.default, + self.spec_desc, + ], + ) + self.rerank_list.select( + self.select_rerank, + inputs=self.rerank_list, + outputs=[self.selected_rerank_name], + show_progress="hidden", + ) + self.selected_rerank_name.change( + self.on_selected_rerank_change, + inputs=[self.selected_rerank_name], + outputs=[ + self._selected_panel, + self._selected_panel_btn, + # delete section + self.btn_delete, + self.btn_delete_yes, + self.btn_delete_no, + # edit section + self.edit_spec, + 
+                self.edit_spec_desc,
+                self.edit_default,
+                self._check_connection_panel,
+            ],
+            show_progress="hidden",
+        ).success(lambda: gr.update(value=""), outputs=[self.connection_logs])
+
+        self.btn_delete.click(
+            self.on_btn_delete_click,
+            inputs=[],
+            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+            show_progress="hidden",
+        )
+        self.btn_delete_yes.click(
+            self.delete_rerank,
+            inputs=[self.selected_rerank_name],
+            outputs=[self.selected_rerank_name],
+            show_progress="hidden",
+        ).then(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self.btn_delete_no.click(
+            lambda: (
+                gr.update(visible=True),
+                gr.update(visible=False),
+                gr.update(visible=False),
+            ),
+            inputs=[],
+            outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
+            show_progress="hidden",
+        )
+        self.btn_edit_save.click(
+            self.save_rerank,
+            inputs=[
+                self.selected_rerank_name,
+                self.edit_default,
+                self.edit_spec,
+            ],
+            show_progress="hidden",
+        ).then(
+            self.list_rerankings,
+            inputs=[],
+            outputs=[self.rerank_list],
+        )
+        self.btn_close.click(lambda: "", outputs=[self.selected_rerank_name])
+
+        self.btn_test_connection.click(
+            self.check_connection,
+            inputs=[self.selected_rerank_name, self.edit_spec],
+            outputs=[self.connection_logs],
+        )
+
+    def create_rerank(self, name, choices, spec, default):
+        try:
+            spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader)
+            spec["__type__"] = (
+                reranking_models_manager.vendors()[choices].__module__
+                + "."
+                + reranking_models_manager.vendors()[choices].__qualname__
+            )
+
+            reranking_models_manager.add(name, spec=spec, default=default)
+            gr.Info(f'Created Reranking model "{name}" successfully')
+        except Exception as e:
+            raise gr.Error(f"Failed to create Reranking model {name}: {e}")
+
+    def list_rerankings(self):
+        """List the Reranking models"""
+        items = []
+        for item in reranking_models_manager.info().values():
+            record = {}
+            record["name"] = item["name"]
+            record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
+            record["default"] = item["default"]
+            items.append(record)
+
+        if items:
+            rerank_list = pd.DataFrame.from_records(items)
+        else:
+            rerank_list = pd.DataFrame.from_records(
+                [{"name": "-", "vendor": "-", "default": "-"}]
+            )
+
+        return rerank_list
+
+    def select_rerank(self, rerank_list, ev: gr.SelectData):
+        if ev.value == "-" and ev.index[0] == 0:
Please add first") + return "" + + if not ev.selected: + return "" + + return rerank_list["name"][ev.index[0]] + + def on_selected_rerank_change(self, selected_rerank_name): + if selected_rerank_name == "": + _check_connection_panel = gr.update(visible=False) + _selected_panel = gr.update(visible=False) + _selected_panel_btn = gr.update(visible=False) + btn_delete = gr.update(visible=True) + btn_delete_yes = gr.update(visible=False) + btn_delete_no = gr.update(visible=False) + edit_spec = gr.update(value="") + edit_spec_desc = gr.update(value="") + edit_default = gr.update(value=False) + else: + _check_connection_panel = gr.update(visible=True) + _selected_panel = gr.update(visible=True) + _selected_panel_btn = gr.update(visible=True) + btn_delete = gr.update(visible=True) + btn_delete_yes = gr.update(visible=False) + btn_delete_no = gr.update(visible=False) + + info = deepcopy(reranking_models_manager.info()[selected_rerank_name]) + vendor_str = info["spec"].pop("__type__", "-").split(".")[-1] + vendor = reranking_models_manager.vendors()[vendor_str] + + edit_spec = yaml.dump(info["spec"]) + edit_spec_desc = format_description(vendor) + edit_default = info["default"] + + return ( + _selected_panel, + _selected_panel_btn, + btn_delete, + btn_delete_yes, + btn_delete_no, + edit_spec, + edit_spec_desc, + edit_default, + _check_connection_panel, + ) + + def on_btn_delete_click(self): + btn_delete = gr.update(visible=False) + btn_delete_yes = gr.update(visible=True) + btn_delete_no = gr.update(visible=True) + + return btn_delete, btn_delete_yes, btn_delete_no + + def check_connection(self, selected_rerank_name, selected_spec): + log_content: str = "" + try: + log_content += f"- Testing model: {selected_rerank_name}
" + yield log_content + + # Parse content & init model + info = deepcopy(reranking_models_manager.info()[selected_rerank_name]) + + # Parse content & create dummy response + spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader) + info["spec"].update(spec) + + rerank = deserialize(info["spec"], safe=False) + + if rerank is None: + raise Exception(f"Can not found model: {selected_rerank_name}") + + log_content += "- Sending a message ([`Hello`], `Hi`)
" + yield log_content + _ = rerank(["Hello"], "Hi") + + log_content += ( + "- Connection success. " + "
" + ) + yield log_content + + gr.Info(f"Embedding {selected_rerank_name} connect successfully") + except Exception as e: + print(e) + log_content += ( + f"- Connection failed. " + f"Got error:\n {str(e)}" + ) + yield log_content + + return log_content + + def save_rerank(self, selected_rerank_name, default, spec): + try: + spec = yaml.load(spec, Loader=YAMLNoDateSafeLoader) + spec["__type__"] = reranking_models_manager.info()[selected_rerank_name][ + "spec" + ]["__type__"] + reranking_models_manager.update( + selected_rerank_name, spec=spec, default=default + ) + gr.Info(f'Save Reranking model "{selected_rerank_name}" successfully') + except Exception as e: + gr.Error(f'Failed to save Embedding model "{selected_rerank_name}": {e}') + + def delete_rerank(self, selected_rerank_name): + try: + reranking_models_manager.delete(selected_rerank_name) + except Exception as e: + gr.Error(f'Failed to delete Reranking model "{selected_rerank_name}": {e}') + return selected_rerank_name + + return "" diff --git a/libs/ktem/ktem/utils/render.py b/libs/ktem/ktem/utils/render.py index 9176627..3e6c434 100644 --- a/libs/ktem/ktem/utils/render.py +++ b/libs/ktem/ktem/utils/render.py @@ -154,9 +154,9 @@ class Render: if doc.metadata.get("llm_trulens_score") is not None else 0.0 ) - cohere_reranking_score = ( - round(doc.metadata["cohere_reranking_score"], 2) - if doc.metadata.get("cohere_reranking_score") is not None + reranking_score = ( + round(doc.metadata["reranking_score"], 2) + if doc.metadata.get("reranking_score") is not None else 0.0 ) item_type_prefix = doc.metadata.get("type", "") @@ -166,8 +166,8 @@ class Render: if llm_reranking_score > 0: relevant_score = llm_reranking_score - elif cohere_reranking_score > 0: - relevant_score = cohere_reranking_score + elif reranking_score > 0: + relevant_score = reranking_score else: relevant_score = 0.0 @@ -179,7 +179,7 @@ class Render: "  LLM relevant score:" f" {llm_reranking_score}
" "  Reranking score:" - f" {cohere_reranking_score}
", + f" {reranking_score}
", ) text = doc.text if not override_text else override_text