From a8725710af8608200f270a6172e250b97b643c11 Mon Sep 17 00:00:00 2001 From: "Duc Nguyen (john)" Date: Thu, 25 Apr 2024 17:18:12 +0700 Subject: [PATCH] Allow users to select reasoning pipeline. Fix small issues with user UI, cohere name (#50) * Fix user page * Allow changing LLM in reasoning pipeline * Fix CohereEmbedding name --- libs/kotaemon/kotaemon/embeddings/__init__.py | 4 +-- .../kotaemon/embeddings/langchain_based.py | 2 +- libs/kotaemon/tests/test_embedding_models.py | 4 +-- libs/ktem/ktem/llms/manager.py | 11 ++++++- libs/ktem/ktem/pages/resources/user.py | 4 +-- libs/ktem/ktem/reasoning/simple.py | 32 ++++++++++++++++--- 6 files changed, 44 insertions(+), 13 deletions(-) diff --git a/libs/kotaemon/kotaemon/embeddings/__init__.py b/libs/kotaemon/kotaemon/embeddings/__init__.py index bf5459b..af01ecc 100644 --- a/libs/kotaemon/kotaemon/embeddings/__init__.py +++ b/libs/kotaemon/kotaemon/embeddings/__init__.py @@ -3,7 +3,7 @@ from .endpoint_based import EndpointEmbeddings from .fastembed import FastEmbedEmbeddings from .langchain_based import ( LCAzureOpenAIEmbeddings, - LCCohereEmbdeddings, + LCCohereEmbeddings, LCHuggingFaceEmbeddings, LCOpenAIEmbeddings, ) @@ -14,7 +14,7 @@ __all__ = [ "EndpointEmbeddings", "LCOpenAIEmbeddings", "LCAzureOpenAIEmbeddings", - "LCCohereEmbdeddings", + "LCCohereEmbeddings", "LCHuggingFaceEmbeddings", "OpenAIEmbeddings", "AzureOpenAIEmbeddings", diff --git a/libs/kotaemon/kotaemon/embeddings/langchain_based.py b/libs/kotaemon/kotaemon/embeddings/langchain_based.py index 3320ff5..aa2bb04 100644 --- a/libs/kotaemon/kotaemon/embeddings/langchain_based.py +++ b/libs/kotaemon/kotaemon/embeddings/langchain_based.py @@ -159,7 +159,7 @@ class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings): return AzureOpenAIEmbeddings -class LCCohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings): +class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings): """Wrapper around Langchain's Cohere embedding, focusing on key parameters""" def __init__( diff --git a/libs/kotaemon/tests/test_embedding_models.py b/libs/kotaemon/tests/test_embedding_models.py index 40cb750..2d63ec7 100644 --- a/libs/kotaemon/tests/test_embedding_models.py +++ b/libs/kotaemon/tests/test_embedding_models.py @@ -9,7 +9,7 @@ from kotaemon.embeddings import ( AzureOpenAIEmbeddings, FastEmbedEmbeddings, LCAzureOpenAIEmbeddings, - LCCohereEmbdeddings, + LCCohereEmbeddings, LCHuggingFaceEmbeddings, OpenAIEmbeddings, ) @@ -148,7 +148,7 @@ def test_lchuggingface_embeddings( side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]], ) def test_lccohere_embeddings(langchain_cohere_embedding_call): - model = LCCohereEmbdeddings( + model = LCCohereEmbeddings( model="embed-english-light-v2.0", cohere_api_key="my-api-key" ) diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py index 71ad425..6baa759 100644 --- a/libs/ktem/ktem/llms/manager.py +++ b/libs/ktem/ktem/llms/manager.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Optional, Type, overload from sqlalchemy import select from sqlalchemy.orm import Session @@ -71,6 +71,14 @@ class LLMManager: """Check if model exists""" return key in self._models + @overload + def get(self, key: str, default: None) -> Optional[ChatLLM]: + ... + + @overload + def get(self, key: str, default: ChatLLM) -> ChatLLM: + ... + def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]: """Get model by name with default value""" return self._models.get(key, default) @@ -138,6 +146,7 @@ class LLMManager: def add(self, name: str, spec: dict, default: bool): """Add a new model to the pool""" + name = name.strip() if not name: raise ValueError("Name must not be empty") diff --git a/libs/ktem/ktem/pages/resources/user.py b/libs/ktem/ktem/pages/resources/user.py index 6411b30..2b65075 100644 --- a/libs/ktem/ktem/pages/resources/user.py +++ b/libs/ktem/ktem/pages/resources/user.py @@ -142,7 +142,7 @@ class UserManagement(BasePage): ) self.admin_edit = gr.Checkbox(label="Admin") - with gr.Row() as self._selected_panel_btn: + with gr.Row(visible=False) as self._selected_panel_btn: with gr.Column(): self.btn_edit_save = gr.Button("Save") with gr.Column(): @@ -338,7 +338,7 @@ class UserManagement(BasePage): if not ev.selected: return -1 - return user_list["id"][ev.index[0]] + return int(user_list["id"][ev.index[0]]) def on_selected_user_change(self, selected_user_id): if selected_user_id == -1: diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 1643c50..5118b8f 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -680,12 +680,15 @@ class FullQAPipeline(BaseReasoning): retrievers: the retrievers to use """ prefix = f"reasoning.options.{cls.get_info()['id']}" - pipeline = FullQAPipeline(retrievers=retrievers) + pipeline = cls(retrievers=retrievers) + + llm_name = settings.get(f"{prefix}.llm", None) + llm = llms.get(llm_name, llms.get_default()) # answering pipeline configuration answer_pipeline = pipeline.answering_pipeline - answer_pipeline.llm = llms.get_default() - answer_pipeline.citation_pipeline.llm = llms.get_default() + answer_pipeline.llm = llm + answer_pipeline.citation_pipeline.llm = llm answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"] answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"] answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"] @@ -694,14 +697,14 @@ class FullQAPipeline(BaseReasoning): settings["reasoning.lang"], "English" ) - pipeline.add_query_context.llm = llms.get_default() + pipeline.add_query_context.llm = llm pipeline.add_query_context.n_last_interactions = settings[ f"{prefix}.n_last_interactions" ] pipeline.trigger_context = settings[f"{prefix}.trigger_context"] pipeline.use_rewrite = states.get("app", {}).get("regen", False) - pipeline.rewrite_pipeline.llm = llms.get_default() + pipeline.rewrite_pipeline.llm = llm pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get( settings["reasoning.lang"], "English" ) @@ -709,7 +712,26 @@ class FullQAPipeline(BaseReasoning): @classmethod def get_user_settings(cls) -> dict: + from ktem.llms.manager import llms + + llm = "" + choices = [("(default)", "")] + try: + choices += [(_, _) for _ in llms.options().keys()] + except Exception as e: + logger.exception(f"Failed to get LLM options: {e}") + return { + "llm": { + "name": "Language model", + "value": llm, + "component": "dropdown", + "choices": choices, + "info": ( + "The language model to use for generating the answer. If None, " + "the application default language model will be used." + ), + }, "highlight_citation": { "name": "Highlight Citation", "value": False,