Improve LLMs and Embedding models resources experience (#21)
* Fix inconsistent default values * Disallow LLM's empty name. Handle LLM creation error on UI
This commit is contained in:
committed by
GitHub
parent
f3e82b2e70
commit
3ed50b0f10
@@ -19,7 +19,7 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
|
|||||||
|
|
||||||
_dependencies = ["openai"]
|
_dependencies = ["openai"]
|
||||||
|
|
||||||
api_key: str = Param(help="API key", required=True)
|
api_key: str = Param(None, help="API key", required=True)
|
||||||
timeout: Optional[float] = Param(None, help="Timeout for the API request.")
|
timeout: Optional[float] = Param(None, help="Timeout for the API request.")
|
||||||
max_retries: Optional[int] = Param(
|
max_retries: Optional[int] = Param(
|
||||||
None, help="Maximum number of retries for the API request."
|
None, help="Maximum number of retries for the API request."
|
||||||
@@ -88,6 +88,7 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
|
|||||||
base_url: Optional[str] = Param(None, help="OpenAI base URL")
|
base_url: Optional[str] = Param(None, help="OpenAI base URL")
|
||||||
organization: Optional[str] = Param(None, help="OpenAI organization")
|
organization: Optional[str] = Param(None, help="OpenAI organization")
|
||||||
model: str = Param(
|
model: str = Param(
|
||||||
|
None,
|
||||||
help=(
|
help=(
|
||||||
"ID of the model to use. You can go to [Model overview](https://platform."
|
"ID of the model to use. You can go to [Model overview](https://platform."
|
||||||
"openai.com/docs/models/overview) to see the available models."
|
"openai.com/docs/models/overview) to see the available models."
|
||||||
@@ -131,14 +132,16 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):
|
|||||||
|
|
||||||
class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
|
class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):
|
||||||
azure_endpoint: str = Param(
|
azure_endpoint: str = Param(
|
||||||
|
None,
|
||||||
help=(
|
help=(
|
||||||
"HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
|
"HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
|
||||||
"azure_deployment, and api_version parameters are used to construct "
|
"azure_deployment, and api_version parameters are used to construct "
|
||||||
"the full URL for the Azure OpenAI model."
|
"the full URL for the Azure OpenAI model."
|
||||||
)
|
),
|
||||||
|
required=True,
|
||||||
)
|
)
|
||||||
azure_deployment: str = Param(help="Azure deployment name", required=True)
|
azure_deployment: str = Param(None, help="Azure deployment name", required=True)
|
||||||
api_version: str = Param(help="Azure model version", required=True)
|
api_version: str = Param(None, help="Azure model version", required=True)
|
||||||
azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
|
azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
|
||||||
azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")
|
azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")
|
||||||
|
|
||||||
|
@@ -140,8 +140,17 @@ class LLMManager:
|
|||||||
|
|
||||||
def add(self, name: str, spec: dict, default: bool):
|
def add(self, name: str, spec: dict, default: bool):
|
||||||
"""Add a new model to the pool"""
|
"""Add a new model to the pool"""
|
||||||
|
if not name:
|
||||||
|
raise ValueError("Name must not be empty")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
|
|
||||||
|
if default:
|
||||||
|
# turn all models to non-default
|
||||||
|
session.query(LLMTable).update({"default": False})
|
||||||
|
session.commit()
|
||||||
|
|
||||||
item = LLMTable(name=name, spec=spec, default=default)
|
item = LLMTable(name=name, spec=spec, default=default)
|
||||||
session.add(item)
|
session.add(item)
|
||||||
session.commit()
|
session.commit()
|
||||||
@@ -164,6 +173,9 @@ class LLMManager:
|
|||||||
|
|
||||||
def update(self, name: str, spec: dict, default: bool):
|
def update(self, name: str, spec: dict, default: bool):
|
||||||
"""Update a model in the pool"""
|
"""Update a model in the pool"""
|
||||||
|
if not name:
|
||||||
|
raise ValueError("Name must not be empty")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
|
|
||||||
|
@@ -140,7 +140,7 @@ class LLMManagement(BasePage):
|
|||||||
self.create_llm,
|
self.create_llm,
|
||||||
inputs=[self.name, self.llm_choices, self.spec, self.default],
|
inputs=[self.name, self.llm_choices, self.spec, self.default],
|
||||||
outputs=None,
|
outputs=None,
|
||||||
).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then(
|
).success(self.list_llms, inputs=None, outputs=[self.llm_list]).success(
|
||||||
lambda: ("", None, "", False, self.spec_desc_default),
|
lambda: ("", None, "", False, self.spec_desc_default),
|
||||||
outputs=[
|
outputs=[
|
||||||
self.name,
|
self.name,
|
||||||
@@ -229,7 +229,7 @@ class LLMManagement(BasePage):
|
|||||||
llms.add(name, spec=spec, default=default)
|
llms.add(name, spec=spec, default=default)
|
||||||
gr.Info(f"LLM {name} created successfully")
|
gr.Info(f"LLM {name} created successfully")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
gr.Error(f"Failed to create LLM {name}: {e}")
|
raise gr.Error(f"Failed to create LLM {name}: {e}")
|
||||||
|
|
||||||
def list_llms(self):
|
def list_llms(self):
|
||||||
"""List the LLMs"""
|
"""List the LLMs"""
|
||||||
|
Reference in New Issue
Block a user