From 3ed50b0f1089658478a2132808c3af518338bea1 Mon Sep 17 00:00:00 2001 From: "Duc Nguyen (john)" Date: Thu, 11 Apr 2024 07:50:53 +0700 Subject: [PATCH] Improve LLMs and Embedding models resources experience (#21) * Fix inconsistent default values * Disallow LLM's empty name. Handle LLM creation error on UI --- libs/kotaemon/kotaemon/embeddings/openai.py | 11 +++++++---- libs/ktem/ktem/llms/manager.py | 12 ++++++++++++ libs/ktem/ktem/llms/ui.py | 4 ++-- 3 files changed, 21 insertions(+), 6 deletions(-) diff --git a/libs/kotaemon/kotaemon/embeddings/openai.py b/libs/kotaemon/kotaemon/embeddings/openai.py index 6f9a246..74655dc 100644 --- a/libs/kotaemon/kotaemon/embeddings/openai.py +++ b/libs/kotaemon/kotaemon/embeddings/openai.py @@ -19,7 +19,7 @@ class BaseOpenAIEmbeddings(BaseEmbeddings): _dependencies = ["openai"] - api_key: str = Param(help="API key", required=True) + api_key: str = Param(None, help="API key", required=True) timeout: Optional[float] = Param(None, help="Timeout for the API request.") max_retries: Optional[int] = Param( None, help="Maximum number of retries for the API request." @@ -88,6 +88,7 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings): base_url: Optional[str] = Param(None, help="OpenAI base URL") organization: Optional[str] = Param(None, help="OpenAI organization") model: str = Param( + None, help=( "ID of the model to use. You can go to [Model overview](https://platform." "openai.com/docs/models/overview) to see the available models." @@ -131,14 +132,16 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings): class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings): azure_endpoint: str = Param( + None, help=( "HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, " "azure_deployment, and api_version parameters are used to construct " "the full URL for the Azure OpenAI model." - ) + ), + required=True, ) - azure_deployment: str = Param(help="Azure deployment name", required=True) - api_version: str = Param(help="Azure model version", required=True) + azure_deployment: str = Param(None, help="Azure deployment name", required=True) + api_version: str = Param(None, help="Azure model version", required=True) azure_ad_token: Optional[str] = Param(None, help="Azure AD token") azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider") diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py index f9ad763..0ef64e0 100644 --- a/libs/ktem/ktem/llms/manager.py +++ b/libs/ktem/ktem/llms/manager.py @@ -140,8 +140,17 @@ class LLMManager: def add(self, name: str, spec: dict, default: bool): """Add a new model to the pool""" + if not name: + raise ValueError("Name must not be empty") + try: with Session(engine) as session: + + if default: + # turn all models to non-default + session.query(LLMTable).update({"default": False}) + session.commit() + item = LLMTable(name=name, spec=spec, default=default) session.add(item) session.commit() @@ -164,6 +173,9 @@ class LLMManager: def update(self, name: str, spec: dict, default: bool): """Update a model in the pool""" + if not name: + raise ValueError("Name must not be empty") + try: with Session(engine) as session: diff --git a/libs/ktem/ktem/llms/ui.py b/libs/ktem/ktem/llms/ui.py index 7644e27..02d13a2 100644 --- a/libs/ktem/ktem/llms/ui.py +++ b/libs/ktem/ktem/llms/ui.py @@ -140,7 +140,7 @@ class LLMManagement(BasePage): self.create_llm, inputs=[self.name, self.llm_choices, self.spec, self.default], outputs=None, - ).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then( + ).success(self.list_llms, inputs=None, outputs=[self.llm_list]).success( lambda: ("", None, "", False, self.spec_desc_default), outputs=[ self.name, @@ -229,7 +229,7 @@ class LLMManagement(BasePage): llms.add(name, spec=spec, default=default) gr.Info(f"LLM {name} created successfully") except Exception as e: - gr.Error(f"Failed to create LLM {name}: {e}") + raise gr.Error(f"Failed to create LLM {name}: {e}") def list_llms(self): """List the LLMs"""