feat: add Google embedding support & update setup (#550) bump:patch
This commit is contained in:
parent
159f4da7c9
commit
b016a84b97
|
@ -26,6 +26,7 @@ if not KH_APP_VERSION:
|
||||||
|
|
||||||
KH_ENABLE_FIRST_SETUP = True
|
KH_ENABLE_FIRST_SETUP = True
|
||||||
KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
|
KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
|
||||||
|
KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/")
|
||||||
|
|
||||||
# App can be ran from anywhere and it's not trivial to decide where to store app data.
|
# App can be ran from anywhere and it's not trivial to decide where to store app data.
|
||||||
# So let's use the same directory as the flowsetting.py file.
|
# So let's use the same directory as the flowsetting.py file.
|
||||||
|
@ -162,7 +163,7 @@ if config("LOCAL_MODEL", default=""):
|
||||||
KH_LLMS["ollama"] = {
|
KH_LLMS["ollama"] = {
|
||||||
"spec": {
|
"spec": {
|
||||||
"__type__": "kotaemon.llms.ChatOpenAI",
|
"__type__": "kotaemon.llms.ChatOpenAI",
|
||||||
"base_url": "http://localhost:11434/v1/",
|
"base_url": KH_OLLAMA_URL,
|
||||||
"model": config("LOCAL_MODEL", default="llama3.1:8b"),
|
"model": config("LOCAL_MODEL", default="llama3.1:8b"),
|
||||||
"api_key": "ollama",
|
"api_key": "ollama",
|
||||||
},
|
},
|
||||||
|
@ -171,7 +172,7 @@ if config("LOCAL_MODEL", default=""):
|
||||||
KH_EMBEDDINGS["ollama"] = {
|
KH_EMBEDDINGS["ollama"] = {
|
||||||
"spec": {
|
"spec": {
|
||||||
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
|
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
|
||||||
"base_url": "http://localhost:11434/v1/",
|
"base_url": KH_OLLAMA_URL,
|
||||||
"model": config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"),
|
"model": config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"),
|
||||||
"api_key": "ollama",
|
"api_key": "ollama",
|
||||||
},
|
},
|
||||||
|
@ -195,11 +196,11 @@ KH_LLMS["claude"] = {
|
||||||
},
|
},
|
||||||
"default": False,
|
"default": False,
|
||||||
}
|
}
|
||||||
KH_LLMS["gemini"] = {
|
KH_LLMS["google"] = {
|
||||||
"spec": {
|
"spec": {
|
||||||
"__type__": "kotaemon.llms.chats.LCGeminiChat",
|
"__type__": "kotaemon.llms.chats.LCGeminiChat",
|
||||||
"model_name": "gemini-1.5-pro",
|
"model_name": "gemini-1.5-flash",
|
||||||
"api_key": "your-key",
|
"api_key": config("GOOGLE_API_KEY", default="your-key"),
|
||||||
},
|
},
|
||||||
"default": False,
|
"default": False,
|
||||||
}
|
}
|
||||||
|
@ -231,6 +232,13 @@ KH_EMBEDDINGS["cohere"] = {
|
||||||
},
|
},
|
||||||
"default": False,
|
"default": False,
|
||||||
}
|
}
|
||||||
|
KH_EMBEDDINGS["google"] = {
|
||||||
|
"spec": {
|
||||||
|
"__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
|
||||||
|
"model": "models/text-embedding-004",
|
||||||
|
"google_api_key": config("GOOGLE_API_KEY", default="your-key"),
|
||||||
|
}
|
||||||
|
}
|
||||||
# KH_EMBEDDINGS["huggingface"] = {
|
# KH_EMBEDDINGS["huggingface"] = {
|
||||||
# "spec": {
|
# "spec": {
|
||||||
# "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings",
|
# "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings",
|
||||||
|
|
|
@ -4,6 +4,7 @@ from .fastembed import FastEmbedEmbeddings
|
||||||
from .langchain_based import (
|
from .langchain_based import (
|
||||||
LCAzureOpenAIEmbeddings,
|
LCAzureOpenAIEmbeddings,
|
||||||
LCCohereEmbeddings,
|
LCCohereEmbeddings,
|
||||||
|
LCGoogleEmbeddings,
|
||||||
LCHuggingFaceEmbeddings,
|
LCHuggingFaceEmbeddings,
|
||||||
LCOpenAIEmbeddings,
|
LCOpenAIEmbeddings,
|
||||||
)
|
)
|
||||||
|
@ -18,6 +19,7 @@ __all__ = [
|
||||||
"LCAzureOpenAIEmbeddings",
|
"LCAzureOpenAIEmbeddings",
|
||||||
"LCCohereEmbeddings",
|
"LCCohereEmbeddings",
|
||||||
"LCHuggingFaceEmbeddings",
|
"LCHuggingFaceEmbeddings",
|
||||||
|
"LCGoogleEmbeddings",
|
||||||
"OpenAIEmbeddings",
|
"OpenAIEmbeddings",
|
||||||
"AzureOpenAIEmbeddings",
|
"AzureOpenAIEmbeddings",
|
||||||
"FastEmbedEmbeddings",
|
"FastEmbedEmbeddings",
|
||||||
|
|
|
@ -219,3 +219,38 @@ class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||||
|
|
||||||
return HuggingFaceBgeEmbeddings
|
return HuggingFaceBgeEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
class LCGoogleEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
|
||||||
|
"""Wrapper around Langchain's Google GenAI embedding, focusing on key parameters"""
|
||||||
|
|
||||||
|
google_api_key: str = Param(
|
||||||
|
help="API key (https://aistudio.google.com/app/apikey)",
|
||||||
|
default=None,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
model: str = Param(
|
||||||
|
help="Model name to use (https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding-and-embedding)", # noqa
|
||||||
|
default="models/text-embedding-004",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: str = "models/text-embedding-004",
|
||||||
|
google_api_key: Optional[str] = None,
|
||||||
|
**params,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
model=model,
|
||||||
|
google_api_key=google_api_key,
|
||||||
|
**params,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_lc_class(self):
|
||||||
|
try:
|
||||||
|
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("Please install langchain-google-genai")
|
||||||
|
|
||||||
|
return GoogleGenerativeAIEmbeddings
|
||||||
|
|
|
@ -57,6 +57,7 @@ class EmbeddingManager:
|
||||||
AzureOpenAIEmbeddings,
|
AzureOpenAIEmbeddings,
|
||||||
FastEmbedEmbeddings,
|
FastEmbedEmbeddings,
|
||||||
LCCohereEmbeddings,
|
LCCohereEmbeddings,
|
||||||
|
LCGoogleEmbeddings,
|
||||||
LCHuggingFaceEmbeddings,
|
LCHuggingFaceEmbeddings,
|
||||||
OpenAIEmbeddings,
|
OpenAIEmbeddings,
|
||||||
TeiEndpointEmbeddings,
|
TeiEndpointEmbeddings,
|
||||||
|
@ -68,6 +69,7 @@ class EmbeddingManager:
|
||||||
FastEmbedEmbeddings,
|
FastEmbedEmbeddings,
|
||||||
LCCohereEmbeddings,
|
LCCohereEmbeddings,
|
||||||
LCHuggingFaceEmbeddings,
|
LCHuggingFaceEmbeddings,
|
||||||
|
LCGoogleEmbeddings,
|
||||||
TeiEndpointEmbeddings,
|
TeiEndpointEmbeddings,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,10 @@ from ktem.rerankings.manager import reranking_models_manager as rerankers
|
||||||
from theflow.settings import settings as flowsettings
|
from theflow.settings import settings as flowsettings
|
||||||
|
|
||||||
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
|
KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
|
||||||
DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
|
KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/")
|
||||||
|
DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
|
||||||
|
if DEFAULT_OLLAMA_URL.endswith("/"):
|
||||||
|
DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]
|
||||||
|
|
||||||
|
|
||||||
DEMO_MESSAGE = (
|
DEMO_MESSAGE = (
|
||||||
|
@ -55,8 +58,9 @@ class SetupPage(BasePage):
|
||||||
gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
|
gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
|
||||||
self.radio_model = gr.Radio(
|
self.radio_model = gr.Radio(
|
||||||
[
|
[
|
||||||
("Cohere API (*free registration* available) - recommended", "cohere"),
|
("Cohere API (*free registration*) - recommended", "cohere"),
|
||||||
("OpenAI API (for more advance models)", "openai"),
|
("Google API (*free registration*)", "google"),
|
||||||
|
("OpenAI API (for GPT-based models)", "openai"),
|
||||||
("Local LLM (for completely *private RAG*)", "ollama"),
|
("Local LLM (for completely *private RAG*)", "ollama"),
|
||||||
],
|
],
|
||||||
label="Select your model provider",
|
label="Select your model provider",
|
||||||
|
@ -92,6 +96,18 @@ class SetupPage(BasePage):
|
||||||
show_label=False, placeholder="Cohere API Key"
|
show_label=False, placeholder="Cohere API Key"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with gr.Column(visible=False) as self.google_option:
|
||||||
|
gr.Markdown(
|
||||||
|
(
|
||||||
|
"#### Google API Key\n\n"
|
||||||
|
"(register your free API key "
|
||||||
|
"at https://aistudio.google.com/app/apikey)"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.google_api_key = gr.Textbox(
|
||||||
|
show_label=False, placeholder="Google API Key"
|
||||||
|
)
|
||||||
|
|
||||||
with gr.Column(visible=False) as self.ollama_option:
|
with gr.Column(visible=False) as self.ollama_option:
|
||||||
gr.Markdown(
|
gr.Markdown(
|
||||||
(
|
(
|
||||||
|
@ -119,7 +135,12 @@ class SetupPage(BasePage):
|
||||||
self.openai_api_key.submit,
|
self.openai_api_key.submit,
|
||||||
],
|
],
|
||||||
fn=self.update_model,
|
fn=self.update_model,
|
||||||
inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
|
inputs=[
|
||||||
|
self.cohere_api_key,
|
||||||
|
self.openai_api_key,
|
||||||
|
self.google_api_key,
|
||||||
|
self.radio_model,
|
||||||
|
],
|
||||||
outputs=[self.setup_log],
|
outputs=[self.setup_log],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
)
|
)
|
||||||
|
@ -147,13 +168,19 @@ class SetupPage(BasePage):
|
||||||
fn=self.switch_options_view,
|
fn=self.switch_options_view,
|
||||||
inputs=[self.radio_model],
|
inputs=[self.radio_model],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
outputs=[self.cohere_option, self.openai_option, self.ollama_option],
|
outputs=[
|
||||||
|
self.cohere_option,
|
||||||
|
self.openai_option,
|
||||||
|
self.ollama_option,
|
||||||
|
self.google_option,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_model(
|
def update_model(
|
||||||
self,
|
self,
|
||||||
cohere_api_key,
|
cohere_api_key,
|
||||||
openai_api_key,
|
openai_api_key,
|
||||||
|
google_api_key,
|
||||||
radio_model_value,
|
radio_model_value,
|
||||||
):
|
):
|
||||||
# skip if KH_DEMO_MODE
|
# skip if KH_DEMO_MODE
|
||||||
|
@ -221,12 +248,32 @@ class SetupPage(BasePage):
|
||||||
},
|
},
|
||||||
default=True,
|
default=True,
|
||||||
)
|
)
|
||||||
|
elif radio_model_value == "google":
|
||||||
|
if google_api_key:
|
||||||
|
llms.update(
|
||||||
|
name="google",
|
||||||
|
spec={
|
||||||
|
"__type__": "kotaemon.llms.chats.LCGeminiChat",
|
||||||
|
"model_name": "gemini-1.5-flash",
|
||||||
|
"api_key": google_api_key,
|
||||||
|
},
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
embeddings.update(
|
||||||
|
name="google",
|
||||||
|
spec={
|
||||||
|
"__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
|
||||||
|
"model": "models/text-embedding-004",
|
||||||
|
"google_api_key": google_api_key,
|
||||||
|
},
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
elif radio_model_value == "ollama":
|
elif radio_model_value == "ollama":
|
||||||
llms.update(
|
llms.update(
|
||||||
name="ollama",
|
name="ollama",
|
||||||
spec={
|
spec={
|
||||||
"__type__": "kotaemon.llms.ChatOpenAI",
|
"__type__": "kotaemon.llms.ChatOpenAI",
|
||||||
"base_url": "http://localhost:11434/v1/",
|
"base_url": KH_OLLAMA_URL,
|
||||||
"model": "llama3.1:8b",
|
"model": "llama3.1:8b",
|
||||||
"api_key": "ollama",
|
"api_key": "ollama",
|
||||||
},
|
},
|
||||||
|
@ -236,7 +283,7 @@ class SetupPage(BasePage):
|
||||||
name="ollama",
|
name="ollama",
|
||||||
spec={
|
spec={
|
||||||
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
|
"__type__": "kotaemon.embeddings.OpenAIEmbeddings",
|
||||||
"base_url": "http://localhost:11434/v1/",
|
"base_url": KH_OLLAMA_URL,
|
||||||
"model": "nomic-embed-text",
|
"model": "nomic-embed-text",
|
||||||
"api_key": "ollama",
|
"api_key": "ollama",
|
||||||
},
|
},
|
||||||
|
@ -345,9 +392,9 @@ class SetupPage(BasePage):
|
||||||
return default_settings
|
return default_settings
|
||||||
|
|
||||||
def switch_options_view(self, radio_model_value):
|
def switch_options_view(self, radio_model_value):
|
||||||
components_visible = [gr.update(visible=False) for _ in range(3)]
|
components_visible = [gr.update(visible=False) for _ in range(4)]
|
||||||
|
|
||||||
values = ["cohere", "openai", "ollama", None]
|
values = ["cohere", "openai", "ollama", "google", None]
|
||||||
assert radio_model_value in values, f"Invalid value {radio_model_value}"
|
assert radio_model_value in values, f"Invalid value {radio_model_value}"
|
||||||
|
|
||||||
if radio_model_value is not None:
|
if radio_model_value is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user