Improve LLMs and Embedding models resources experience (#21)

* Fix inconsistent default values
* Disallow LLM's empty name. Handle LLM creation error on UI
This commit is contained in:
Duc Nguyen (john)
2024-04-11 07:50:53 +07:00
committed by GitHub
parent f3e82b2e70
commit 3ed50b0f10
3 changed files with 21 additions and 6 deletions

View File

@@ -19,7 +19,7 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
_dependencies = ["openai"] _dependencies = ["openai"]
api_key: str = Param(help="API key", required=True) api_key: str = Param(None, help="API key", required=True)
timeout: Optional[float] = Param(None, help="Timeout for the API request.") timeout: Optional[float] = Param(None, help="Timeout for the API request.")
max_retries: Optional[int] = Param( max_retries: Optional[int] = Param(
None, help="Maximum number of retries for the API request." None, help="Maximum number of retries for the API request."
@@ -88,6 +88,7 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
base_url: Optional[str] = Param(None, help="OpenAI base URL") base_url: Optional[str] = Param(None, help="OpenAI base URL")
organization: Optional[str] = Param(None, help="OpenAI organization") organization: Optional[str] = Param(None, help="OpenAI organization")
model: str = Param( model: str = Param(
None,
help=( help=(
"ID of the model to use. You can go to [Model overview](https://platform." "ID of the model to use. You can go to [Model overview](https://platform."
"openai.com/docs/models/overview) to see the available models." "openai.com/docs/models/overview) to see the available models."
@@ -131,14 +132,16 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings): class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
azure_endpoint: str = Param( azure_endpoint: str = Param(
None,
help=( help=(
"HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, " "HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
"azure_deployment, and api_version parameters are used to construct " "azure_deployment, and api_version parameters are used to construct "
"the full URL for the Azure OpenAI model." "the full URL for the Azure OpenAI model."
) ),
required=True,
) )
azure_deployment: str = Param(help="Azure deployment name", required=True) azure_deployment: str = Param(None, help="Azure deployment name", required=True)
api_version: str = Param(help="Azure model version", required=True) api_version: str = Param(None, help="Azure model version", required=True)
azure_ad_token: Optional[str] = Param(None, help="Azure AD token") azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider") azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")

View File

@@ -140,8 +140,17 @@ class LLMManager:
def add(self, name: str, spec: dict, default: bool): def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool""" """Add a new model to the pool"""
if not name:
raise ValueError("Name must not be empty")
try: try:
with Session(engine) as session: with Session(engine) as session:
if default:
# turn all models to non-default
session.query(LLMTable).update({"default": False})
session.commit()
item = LLMTable(name=name, spec=spec, default=default) item = LLMTable(name=name, spec=spec, default=default)
session.add(item) session.add(item)
session.commit() session.commit()
@@ -164,6 +173,9 @@ class LLMManager:
def update(self, name: str, spec: dict, default: bool): def update(self, name: str, spec: dict, default: bool):
"""Update a model in the pool""" """Update a model in the pool"""
if not name:
raise ValueError("Name must not be empty")
try: try:
with Session(engine) as session: with Session(engine) as session:

View File

@@ -140,7 +140,7 @@ class LLMManagement(BasePage):
self.create_llm, self.create_llm,
inputs=[self.name, self.llm_choices, self.spec, self.default], inputs=[self.name, self.llm_choices, self.spec, self.default],
outputs=None, outputs=None,
).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then( ).success(self.list_llms, inputs=None, outputs=[self.llm_list]).success(
lambda: ("", None, "", False, self.spec_desc_default), lambda: ("", None, "", False, self.spec_desc_default),
outputs=[ outputs=[
self.name, self.name,
@@ -229,7 +229,7 @@ class LLMManagement(BasePage):
llms.add(name, spec=spec, default=default) llms.add(name, spec=spec, default=default)
gr.Info(f"LLM {name} created successfully") gr.Info(f"LLM {name} created successfully")
except Exception as e: except Exception as e:
gr.Error(f"Failed to create LLM {name}: {e}") raise gr.Error(f"Failed to create LLM {name}: {e}")
def list_llms(self): def list_llms(self):
"""List the LLMs""" """List the LLMs"""