Allow users to select reasoning pipeline. Fix small issues with user UI, cohere name (#50)
* Fix user page * Allow changing LLM in reasoning pipeline * Fix CohereEmbedding name
This commit is contained in:
parent
e29bec6275
commit
a8725710af
|
@ -3,7 +3,7 @@ from .endpoint_based import EndpointEmbeddings
|
||||||
from .fastembed import FastEmbedEmbeddings
|
from .fastembed import FastEmbedEmbeddings
|
||||||
from .langchain_based import (
|
from .langchain_based import (
|
||||||
LCAzureOpenAIEmbeddings,
|
LCAzureOpenAIEmbeddings,
|
||||||
LCCohereEmbdeddings,
|
LCCohereEmbeddings,
|
||||||
LCHuggingFaceEmbeddings,
|
LCHuggingFaceEmbeddings,
|
||||||
LCOpenAIEmbeddings,
|
LCOpenAIEmbeddings,
|
||||||
)
|
)
|
||||||
|
@ -14,7 +14,7 @@ __all__ = [
|
||||||
"EndpointEmbeddings",
|
"EndpointEmbeddings",
|
||||||
"LCOpenAIEmbeddings",
|
"LCOpenAIEmbeddings",
|
||||||
"LCAzureOpenAIEmbeddings",
|
"LCAzureOpenAIEmbeddings",
|
||||||
"LCCohereEmbdeddings",
|
"LCCohereEmbeddings",
|
||||||
"LCHuggingFaceEmbeddings",
|
"LCHuggingFaceEmbeddings",
|
||||||
"OpenAIEmbeddings",
|
"OpenAIEmbeddings",
|
||||||
"AzureOpenAIEmbeddings",
|
"AzureOpenAIEmbeddings",
|
||||||
|
|
|
@ -159,7 +159,7 @@ class LCAzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
return AzureOpenAIEmbeddings
|
return AzureOpenAIEmbeddings
|
||||||
|
|
||||||
|
|
||||||
class LCCohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
|
class LCCohereEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
|
"""Wrapper around Langchain's Cohere embedding, focusing on key parameters"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -9,7 +9,7 @@ from kotaemon.embeddings import (
|
||||||
AzureOpenAIEmbeddings,
|
AzureOpenAIEmbeddings,
|
||||||
FastEmbedEmbeddings,
|
FastEmbedEmbeddings,
|
||||||
LCAzureOpenAIEmbeddings,
|
LCAzureOpenAIEmbeddings,
|
||||||
LCCohereEmbdeddings,
|
LCCohereEmbeddings,
|
||||||
LCHuggingFaceEmbeddings,
|
LCHuggingFaceEmbeddings,
|
||||||
OpenAIEmbeddings,
|
OpenAIEmbeddings,
|
||||||
)
|
)
|
||||||
|
@ -148,7 +148,7 @@ def test_lchuggingface_embeddings(
|
||||||
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
|
side_effect=lambda *args, **kwargs: [[1.0, 2.1, 3.2]],
|
||||||
)
|
)
|
||||||
def test_lccohere_embeddings(langchain_cohere_embedding_call):
|
def test_lccohere_embeddings(langchain_cohere_embedding_call):
|
||||||
model = LCCohereEmbdeddings(
|
model = LCCohereEmbeddings(
|
||||||
model="embed-english-light-v2.0", cohere_api_key="my-api-key"
|
model="embed-english-light-v2.0", cohere_api_key="my-api-key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type, overload
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
@ -71,6 +71,14 @@ class LLMManager:
|
||||||
"""Check if model exists"""
|
"""Check if model exists"""
|
||||||
return key in self._models
|
return key in self._models
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: str, default: None) -> Optional[ChatLLM]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def get(self, key: str, default: ChatLLM) -> ChatLLM:
|
||||||
|
...
|
||||||
|
|
||||||
def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
|
def get(self, key: str, default: Optional[ChatLLM] = None) -> Optional[ChatLLM]:
|
||||||
"""Get model by name with default value"""
|
"""Get model by name with default value"""
|
||||||
return self._models.get(key, default)
|
return self._models.get(key, default)
|
||||||
|
@ -138,6 +146,7 @@ class LLMManager:
|
||||||
|
|
||||||
def add(self, name: str, spec: dict, default: bool):
|
def add(self, name: str, spec: dict, default: bool):
|
||||||
"""Add a new model to the pool"""
|
"""Add a new model to the pool"""
|
||||||
|
name = name.strip()
|
||||||
if not name:
|
if not name:
|
||||||
raise ValueError("Name must not be empty")
|
raise ValueError("Name must not be empty")
|
||||||
|
|
||||||
|
|
|
@ -142,7 +142,7 @@ class UserManagement(BasePage):
|
||||||
)
|
)
|
||||||
self.admin_edit = gr.Checkbox(label="Admin")
|
self.admin_edit = gr.Checkbox(label="Admin")
|
||||||
|
|
||||||
with gr.Row() as self._selected_panel_btn:
|
with gr.Row(visible=False) as self._selected_panel_btn:
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
self.btn_edit_save = gr.Button("Save")
|
self.btn_edit_save = gr.Button("Save")
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
@ -338,7 +338,7 @@ class UserManagement(BasePage):
|
||||||
if not ev.selected:
|
if not ev.selected:
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
return user_list["id"][ev.index[0]]
|
return int(user_list["id"][ev.index[0]])
|
||||||
|
|
||||||
def on_selected_user_change(self, selected_user_id):
|
def on_selected_user_change(self, selected_user_id):
|
||||||
if selected_user_id == -1:
|
if selected_user_id == -1:
|
||||||
|
|
|
@ -680,12 +680,15 @@ class FullQAPipeline(BaseReasoning):
|
||||||
retrievers: the retrievers to use
|
retrievers: the retrievers to use
|
||||||
"""
|
"""
|
||||||
prefix = f"reasoning.options.{cls.get_info()['id']}"
|
prefix = f"reasoning.options.{cls.get_info()['id']}"
|
||||||
pipeline = FullQAPipeline(retrievers=retrievers)
|
pipeline = cls(retrievers=retrievers)
|
||||||
|
|
||||||
|
llm_name = settings.get(f"{prefix}.llm", None)
|
||||||
|
llm = llms.get(llm_name, llms.get_default())
|
||||||
|
|
||||||
# answering pipeline configuration
|
# answering pipeline configuration
|
||||||
answer_pipeline = pipeline.answering_pipeline
|
answer_pipeline = pipeline.answering_pipeline
|
||||||
answer_pipeline.llm = llms.get_default()
|
answer_pipeline.llm = llm
|
||||||
answer_pipeline.citation_pipeline.llm = llms.get_default()
|
answer_pipeline.citation_pipeline.llm = llm
|
||||||
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
|
answer_pipeline.n_last_interactions = settings[f"{prefix}.n_last_interactions"]
|
||||||
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
|
answer_pipeline.enable_citation = settings[f"{prefix}.highlight_citation"]
|
||||||
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
|
answer_pipeline.system_prompt = settings[f"{prefix}.system_prompt"]
|
||||||
|
@ -694,14 +697,14 @@ class FullQAPipeline(BaseReasoning):
|
||||||
settings["reasoning.lang"], "English"
|
settings["reasoning.lang"], "English"
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline.add_query_context.llm = llms.get_default()
|
pipeline.add_query_context.llm = llm
|
||||||
pipeline.add_query_context.n_last_interactions = settings[
|
pipeline.add_query_context.n_last_interactions = settings[
|
||||||
f"{prefix}.n_last_interactions"
|
f"{prefix}.n_last_interactions"
|
||||||
]
|
]
|
||||||
|
|
||||||
pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
|
pipeline.trigger_context = settings[f"{prefix}.trigger_context"]
|
||||||
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
|
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
|
||||||
pipeline.rewrite_pipeline.llm = llms.get_default()
|
pipeline.rewrite_pipeline.llm = llm
|
||||||
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
|
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
|
||||||
settings["reasoning.lang"], "English"
|
settings["reasoning.lang"], "English"
|
||||||
)
|
)
|
||||||
|
@ -709,7 +712,26 @@ class FullQAPipeline(BaseReasoning):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_user_settings(cls) -> dict:
|
def get_user_settings(cls) -> dict:
|
||||||
|
from ktem.llms.manager import llms
|
||||||
|
|
||||||
|
llm = ""
|
||||||
|
choices = [("(default)", "")]
|
||||||
|
try:
|
||||||
|
choices += [(_, _) for _ in llms.options().keys()]
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Failed to get LLM options: {e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
"llm": {
|
||||||
|
"name": "Language model",
|
||||||
|
"value": llm,
|
||||||
|
"component": "dropdown",
|
||||||
|
"choices": choices,
|
||||||
|
"info": (
|
||||||
|
"The language model to use for generating the answer. If None, "
|
||||||
|
"the application default language model will be used."
|
||||||
|
),
|
||||||
|
},
|
||||||
"highlight_citation": {
|
"highlight_citation": {
|
||||||
"name": "Highlight Citation",
|
"name": "Highlight Citation",
|
||||||
"value": False,
|
"value": False,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user