feat: add Google embedding support & update setup (#550) bump:patch
parent 159f4da7c9
commit b016a84b97
@@ -57,6 +57,7 @@ class EmbeddingManager:
             AzureOpenAIEmbeddings,
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
+            LCGoogleEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
             TeiEndpointEmbeddings,
@@ -68,6 +69,7 @@ class EmbeddingManager:
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            LCGoogleEmbeddings,
             TeiEndpointEmbeddings,
         ]
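The two hunks above add LCGoogleEmbeddings to both the import block and the vendor list, which is what makes Google selectable as an embedding vendor in the manager. A minimal sketch of using the new vendor directly, assuming its constructor takes the same fields as the spec registered later in this commit (model, google_api_key) and that it is callable like the other kotaemon embedding wrappers:

# Hedged sketch; field names mirror the setup-page spec added in this commit.
from kotaemon.embeddings import LCGoogleEmbeddings

embedding = LCGoogleEmbeddings(
    model="models/text-embedding-004",       # same model the setup page registers
    google_api_key="<YOUR_GOOGLE_API_KEY>",  # from https://aistudio.google.com/app/apikey
)
# Assumed callable interface shared by the kotaemon embedding wrappers:
docs = embedding(["What is Kotaemon?"])
print(len(docs), "embedded document(s)")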
@@ -9,7 +9,10 @@ from ktem.rerankings.manager import reranking_models_manager as rerankers
 from theflow.settings import settings as flowsettings

 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
-DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
+KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/")
+DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
+if DEFAULT_OLLAMA_URL.endswith("/"):
+    DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]


 DEMO_MESSAGE = (
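For reference, with the default value the rewrite above reproduces the endpoint the old constant hard-coded; a quick check of the string manipulation:

# Worked example of the DEFAULT_OLLAMA_URL derivation above.
KH_OLLAMA_URL = "http://localhost:11434/v1/"
DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")  # "http://localhost:11434/api/"
if DEFAULT_OLLAMA_URL.endswith("/"):
    DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]
print(DEFAULT_OLLAMA_URL)  # http://localhost:11434/api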
@@ -55,8 +58,9 @@ class SetupPage(BasePage):
             gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
             self.radio_model = gr.Radio(
                 [
-                    ("Cohere API (*free registration* available) - recommended", "cohere"),
-                    ("OpenAI API (for more advance models)", "openai"),
+                    ("Cohere API (*free registration*) - recommended", "cohere"),
+                    ("Google API (*free registration*)", "google"),
+                    ("OpenAI API (for GPT-based models)", "openai"),
                     ("Local LLM (for completely *private RAG*)", "ollama"),
                 ],
                 label="Select your model provider",
@@ -92,6 +96,18 @@ class SetupPage(BasePage):
                         show_label=False, placeholder="Cohere API Key"
                     )

+                with gr.Column(visible=False) as self.google_option:
+                    gr.Markdown(
+                        (
+                            "#### Google API Key\n\n"
+                            "(register your free API key "
+                            "at https://aistudio.google.com/app/apikey)"
+                        )
+                    )
+                    self.google_api_key = gr.Textbox(
+                        show_label=False, placeholder="Google API Key"
+                    )
+
                 with gr.Column(visible=False) as self.ollama_option:
                     gr.Markdown(
                         (
@@ -119,7 +135,12 @@ class SetupPage(BasePage):
                 self.openai_api_key.submit,
             ],
             fn=self.update_model,
-            inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
+            inputs=[
+                self.cohere_api_key,
+                self.openai_api_key,
+                self.google_api_key,
+                self.radio_model,
+            ],
             outputs=[self.setup_log],
             show_progress="hidden",
         )
@@ -147,13 +168,19 @@ class SetupPage(BasePage):
             fn=self.switch_options_view,
             inputs=[self.radio_model],
             show_progress="hidden",
-            outputs=[self.cohere_option, self.openai_option, self.ollama_option],
+            outputs=[
+                self.cohere_option,
+                self.openai_option,
+                self.ollama_option,
+                self.google_option,
+            ],
         )

     def update_model(
         self,
         cohere_api_key,
         openai_api_key,
+        google_api_key,
         radio_model_value,
     ):
         # skip if KH_DEMO_MODE
@@ -221,12 +248,32 @@ class SetupPage(BasePage):
                     },
                     default=True,
                 )
+        elif radio_model_value == "google":
+            if google_api_key:
+                llms.update(
+                    name="google",
+                    spec={
+                        "__type__": "kotaemon.llms.chats.LCGeminiChat",
+                        "model_name": "gemini-1.5-flash",
+                        "api_key": google_api_key,
+                    },
+                    default=True,
+                )
+                embeddings.update(
+                    name="google",
+                    spec={
+                        "__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
+                        "model": "models/text-embedding-004",
+                        "google_api_key": google_api_key,
+                    },
+                    default=True,
+                )
         elif radio_model_value == "ollama":
             llms.update(
                 name="ollama",
                 spec={
                     "__type__": "kotaemon.llms.ChatOpenAI",
-                    "base_url": "http://localhost:11434/v1/",
+                    "base_url": KH_OLLAMA_URL,
                     "model": "llama3.1:8b",
                     "api_key": "ollama",
                 },
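The new "google" branch above registers both a chat model (LCGeminiChat with gemini-1.5-flash) and an embedding model (LCGoogleEmbeddings with models/text-embedding-004) as defaults when a key is supplied. A hedged sanity-check sketch for the API key, assuming these wrappers delegate to the langchain-google-genai classes of the same family (requires pip install langchain-google-genai):

# Assumption: LCGeminiChat / LCGoogleEmbeddings wrap these LangChain classes.
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

google_api_key = "<YOUR_GOOGLE_API_KEY>"

chat = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=google_api_key)
print(chat.invoke("Reply with the single word: ok").content)

emb = GoogleGenerativeAIEmbeddings(
    model="models/text-embedding-004", google_api_key=google_api_key
)
print(len(emb.embed_query("hello")))  # text-embedding-004 returns a 768-dimensional vector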
@@ -236,7 +283,7 @@ class SetupPage(BasePage):
                 name="ollama",
                 spec={
                     "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
-                    "base_url": "http://localhost:11434/v1/",
+                    "base_url": KH_OLLAMA_URL,
                     "model": "nomic-embed-text",
                     "api_key": "ollama",
                 },
@@ -270,7 +317,7 @@ class SetupPage(BasePage):
                 yield log_content
         except Exception as e:
             log_content += (
-                "Make sure you have download and installed Ollama correctly."
+                "Make sure you have download and installed Ollama correctly. "
                 f"Got error: {str(e)}"
             )
             yield log_content
@@ -345,9 +392,9 @@ class SetupPage(BasePage):
         return default_settings

     def switch_options_view(self, radio_model_value):
-        components_visible = [gr.update(visible=False) for _ in range(3)]
+        components_visible = [gr.update(visible=False) for _ in range(4)]

-        values = ["cohere", "openai", "ollama", None]
+        values = ["cohere", "openai", "ollama", "google", None]
         assert radio_model_value in values, f"Invalid value {radio_model_value}"

         if radio_model_value is not None:
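The hunk above only widens the hidden-by-default list to four entries and adds "google" to the accepted values. A hedged sketch (not the repo's exact code) of how the method is assumed to finish, revealing the selected provider's column in the same order as the outputs list [cohere_option, openai_option, ollama_option, google_option]:

import gradio as gr

def switch_options_view(radio_model_value):
    # Order matches the event's outputs: cohere, openai, ollama, google.
    values = ["cohere", "openai", "ollama", "google", None]
    assert radio_model_value in values, f"Invalid value {radio_model_value}"

    components_visible = [gr.update(visible=False) for _ in range(4)]
    if radio_model_value is not None:
        components_visible[values.index(radio_model_value)] = gr.update(visible=True)
    return components_visible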