feat: add Google embedding support & update setup (#550) bump:patch

Author: Tuan Anh Nguyen Dang (Tadashi_Cin)
Date: 2024-12-04 11:09:57 +07:00 (committed by GitHub)
Parent: 159f4da7c9
Commit: b016a84b97
5 changed files with 109 additions and 15 deletions

View File

@@ -26,6 +26,7 @@ if not KH_APP_VERSION:
 KH_ENABLE_FIRST_SETUP = True
 KH_DEMO_MODE = config("KH_DEMO_MODE", default=False, cast=bool)
+KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/")
 
 # App can be ran from anywhere and it's not trivial to decide where to store app data.
 # So let's use the same directory as the flowsetting.py file.
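The new KH_OLLAMA_URL setting makes the Ollama endpoint configurable instead of hard-coded. A minimal sketch of overriding it, assuming `config` here is python-decouple's (which checks the process environment before any .env file); the remote host below is hypothetical:

# Sketch: overriding KH_OLLAMA_URL without editing flowsettings.py.
import os

os.environ["KH_OLLAMA_URL"] = "http://192.168.1.50:11434/v1/"  # hypothetical remote host

from decouple import config

KH_OLLAMA_URL = config("KH_OLLAMA_URL", default="http://localhost:11434/v1/")
print(KH_OLLAMA_URL)  # -> http://192.168.1.50:11434/v1/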
@@ -162,7 +163,7 @@ if config("LOCAL_MODEL", default=""):
     KH_LLMS["ollama"] = {
         "spec": {
             "__type__": "kotaemon.llms.ChatOpenAI",
-            "base_url": "http://localhost:11434/v1/",
+            "base_url": KH_OLLAMA_URL,
             "model": config("LOCAL_MODEL", default="llama3.1:8b"),
             "api_key": "ollama",
         },
@@ -171,7 +172,7 @@ if config("LOCAL_MODEL", default=""):
     KH_EMBEDDINGS["ollama"] = {
         "spec": {
             "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
-            "base_url": "http://localhost:11434/v1/",
+            "base_url": KH_OLLAMA_URL,
             "model": config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"),
             "api_key": "ollama",
         },
@@ -195,11 +196,11 @@ KH_LLMS["claude"] = {
     },
     "default": False,
 }
-KH_LLMS["gemini"] = {
+KH_LLMS["google"] = {
     "spec": {
         "__type__": "kotaemon.llms.chats.LCGeminiChat",
-        "model_name": "gemini-1.5-pro",
-        "api_key": "your-key",
+        "model_name": "gemini-1.5-flash",
+        "api_key": config("GOOGLE_API_KEY", default="your-key"),
     },
     "default": False,
 }
@@ -231,6 +232,13 @@ KH_EMBEDDINGS["cohere"] = {
     },
     "default": False,
 }
+KH_EMBEDDINGS["google"] = {
+    "spec": {
+        "__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
+        "model": "models/text-embedding-004",
+        "google_api_key": config("GOOGLE_API_KEY", default="your-key"),
+    }
+}
 # KH_EMBEDDINGS["huggingface"] = {
 #     "spec": {
 #         "__type__": "kotaemon.embeddings.LCHuggingFaceEmbeddings",

View File

@ -4,6 +4,7 @@ from .fastembed import FastEmbedEmbeddings
from .langchain_based import ( from .langchain_based import (
LCAzureOpenAIEmbeddings, LCAzureOpenAIEmbeddings,
LCCohereEmbeddings, LCCohereEmbeddings,
LCGoogleEmbeddings,
LCHuggingFaceEmbeddings, LCHuggingFaceEmbeddings,
LCOpenAIEmbeddings, LCOpenAIEmbeddings,
) )
@@ -18,6 +19,7 @@ __all__ = [
     "LCAzureOpenAIEmbeddings",
     "LCCohereEmbeddings",
     "LCHuggingFaceEmbeddings",
+    "LCGoogleEmbeddings",
     "OpenAIEmbeddings",
     "AzureOpenAIEmbeddings",
     "FastEmbedEmbeddings",

View File

@@ -219,3 +219,38 @@ class LCHuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         from langchain.embeddings import HuggingFaceBgeEmbeddings
 
         return HuggingFaceBgeEmbeddings
+
+
+class LCGoogleEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
+    """Wrapper around Langchain's Google GenAI embedding, focusing on key parameters"""
+
+    google_api_key: str = Param(
+        help="API key (https://aistudio.google.com/app/apikey)",
+        default=None,
+        required=True,
+    )
+    model: str = Param(
+        help="Model name to use (https://ai.google.dev/gemini-api/docs/models/gemini#text-embedding-and-embedding)",  # noqa
+        default="models/text-embedding-004",
+        required=True,
+    )
+
+    def __init__(
+        self,
+        model: str = "models/text-embedding-004",
+        google_api_key: Optional[str] = None,
+        **params,
+    ):
+        super().__init__(
+            model=model,
+            google_api_key=google_api_key,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_google_genai import GoogleGenerativeAIEmbeddings
+        except ImportError:
+            raise ImportError("Please install langchain-google-genai")
+
+        return GoogleGenerativeAIEmbeddings
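A minimal usage sketch for the new class, assuming kotaemon's usual callable embedding interface (Documents carrying an .embedding attribute) and that langchain-google-genai is installed; running it needs a real API key:

from kotaemon.embeddings import LCGoogleEmbeddings

model = LCGoogleEmbeddings(
    model="models/text-embedding-004",
    google_api_key="your-key",  # or wire it from GOOGLE_API_KEY as flowsettings does
)
docs = model(["Hello, world!"])  # assumed callable interface returning Documents
print(len(docs[0].embedding))    # text-embedding-004 produces 768-dim vectors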

View File

@@ -57,6 +57,7 @@ class EmbeddingManager:
             AzureOpenAIEmbeddings,
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
+            LCGoogleEmbeddings,
             LCHuggingFaceEmbeddings,
             OpenAIEmbeddings,
             TeiEndpointEmbeddings,
@@ -68,6 +69,7 @@ class EmbeddingManager:
             FastEmbedEmbeddings,
             LCCohereEmbeddings,
             LCHuggingFaceEmbeddings,
+            LCGoogleEmbeddings,
             TeiEndpointEmbeddings,
         ]

View File

@@ -9,7 +9,10 @@ from ktem.rerankings.manager import reranking_models_manager as rerankers
 from theflow.settings import settings as flowsettings
 
 KH_DEMO_MODE = getattr(flowsettings, "KH_DEMO_MODE", False)
-DEFAULT_OLLAMA_URL = "http://localhost:11434/api"
+KH_OLLAMA_URL = getattr(flowsettings, "KH_OLLAMA_URL", "http://localhost:11434/v1/")
+DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
+if DEFAULT_OLLAMA_URL.endswith("/"):
+    DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]
 
 DEMO_MESSAGE = (
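The three added lines derive Ollama's native API base from the OpenAI-compatible URL. A quick worked example of the transformation:

# Worked example of the derivation above:
KH_OLLAMA_URL = "http://localhost:11434/v1/"  # OpenAI-compatible endpoint
DEFAULT_OLLAMA_URL = KH_OLLAMA_URL.replace("v1", "api")
if DEFAULT_OLLAMA_URL.endswith("/"):
    DEFAULT_OLLAMA_URL = DEFAULT_OLLAMA_URL[:-1]
print(DEFAULT_OLLAMA_URL)  # -> http://localhost:11434/api (native Ollama API)

Note that str.replace swaps every occurrence of "v1", so a host name that happens to contain "v1" would be rewritten too.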
@@ -55,8 +58,9 @@ class SetupPage(BasePage):
             gr.Markdown(f"# Welcome to {self._app.app_name} first setup!")
             self.radio_model = gr.Radio(
                 [
-                    ("Cohere API (*free registration* available) - recommended", "cohere"),
-                    ("OpenAI API (for more advance models)", "openai"),
+                    ("Cohere API (*free registration*) - recommended", "cohere"),
+                    ("Google API (*free registration*)", "google"),
+                    ("OpenAI API (for GPT-based models)", "openai"),
                     ("Local LLM (for completely *private RAG*)", "ollama"),
                 ],
                 label="Select your model provider",
@@ -92,6 +96,18 @@ class SetupPage(BasePage):
                     show_label=False, placeholder="Cohere API Key"
                 )
 
+                with gr.Column(visible=False) as self.google_option:
+                    gr.Markdown(
+                        (
+                            "#### Google API Key\n\n"
+                            "(register your free API key "
+                            "at https://aistudio.google.com/app/apikey)"
+                        )
+                    )
+                    self.google_api_key = gr.Textbox(
+                        show_label=False, placeholder="Google API Key"
+                    )
+
                 with gr.Column(visible=False) as self.ollama_option:
                     gr.Markdown(
                         (
@@ -119,7 +135,12 @@ class SetupPage(BasePage):
                 self.openai_api_key.submit,
             ],
             fn=self.update_model,
-            inputs=[self.cohere_api_key, self.openai_api_key, self.radio_model],
+            inputs=[
+                self.cohere_api_key,
+                self.openai_api_key,
+                self.google_api_key,
+                self.radio_model,
+            ],
             outputs=[self.setup_log],
             show_progress="hidden",
         )
@@ -147,13 +168,19 @@ class SetupPage(BasePage):
             fn=self.switch_options_view,
             inputs=[self.radio_model],
             show_progress="hidden",
-            outputs=[self.cohere_option, self.openai_option, self.ollama_option],
+            outputs=[
+                self.cohere_option,
+                self.openai_option,
+                self.ollama_option,
+                self.google_option,
+            ],
         )
 
     def update_model(
         self,
         cohere_api_key,
         openai_api_key,
+        google_api_key,
         radio_model_value,
     ):
         # skip if KH_DEMO_MODE
@@ -221,12 +248,32 @@ class SetupPage(BasePage):
                 },
                 default=True,
             )
+        elif radio_model_value == "google":
+            if google_api_key:
+                llms.update(
+                    name="google",
+                    spec={
+                        "__type__": "kotaemon.llms.chats.LCGeminiChat",
+                        "model_name": "gemini-1.5-flash",
+                        "api_key": google_api_key,
+                    },
+                    default=True,
+                )
+                embeddings.update(
+                    name="google",
+                    spec={
+                        "__type__": "kotaemon.embeddings.LCGoogleEmbeddings",
+                        "model": "models/text-embedding-004",
+                        "google_api_key": google_api_key,
+                    },
+                    default=True,
+                )
         elif radio_model_value == "ollama":
             llms.update(
                 name="ollama",
                 spec={
                     "__type__": "kotaemon.llms.ChatOpenAI",
-                    "base_url": "http://localhost:11434/v1/",
+                    "base_url": KH_OLLAMA_URL,
                     "model": "llama3.1:8b",
                     "api_key": "ollama",
                 },
@@ -236,7 +283,7 @@ class SetupPage(BasePage):
                 name="ollama",
                 spec={
                     "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
-                    "base_url": "http://localhost:11434/v1/",
+                    "base_url": KH_OLLAMA_URL,
                     "model": "nomic-embed-text",
                     "api_key": "ollama",
                 },
@@ -270,7 +317,7 @@ class SetupPage(BasePage):
                 yield log_content
         except Exception as e:
             log_content += (
-                "Make sure you have download and installed Ollama correctly."
+                "Make sure you have download and installed Ollama correctly. "
                 f"Got error: {str(e)}"
             )
             yield log_content
@@ -345,9 +392,9 @@ class SetupPage(BasePage):
             return default_settings
 
     def switch_options_view(self, radio_model_value):
-        components_visible = [gr.update(visible=False) for _ in range(3)]
-        values = ["cohere", "openai", "ollama", None]
+        components_visible = [gr.update(visible=False) for _ in range(4)]
+        values = ["cohere", "openai", "ollama", "google", None]
         assert radio_model_value in values, f"Invalid value {radio_model_value}"
         if radio_model_value is not None: