diff --git a/flowsettings.py b/flowsettings.py index 99881a1..281fd63 100644 --- a/flowsettings.py +++ b/flowsettings.py @@ -175,6 +175,33 @@ if config("LOCAL_MODEL", default=""): "default": False, } +# additional LLM configurations +KH_LLMS["claude"] = { + "spec": { + "__type__": "kotaemon.llms.chats.LCAnthropicChat", + "model_name": "claude-3-5-sonnet-20240620", + "api_key": "your-key", + }, + "default": False, +} +KH_LLMS["gemini"] = { + "spec": { + "__type__": "kotaemon.llms.chats.LCGeminiChat", + "model_name": "gemini-1.5-pro", + "api_key": "your-key", + }, + "default": False, +} +KH_LLMS["groq"] = { + "spec": { + "__type__": "kotaemon.llms.ChatOpenAI", + "base_url": "https://api.groq.com/openai/v1", + "model": "llama-3.1-8b-instant", + "api_key": "your-key", + }, + "default": False, +} + KH_REASONINGS = [ "ktem.reasoning.simple.FullQAPipeline", "ktem.reasoning.simple.FullDecomposeQAPipeline", diff --git a/libs/kotaemon/kotaemon/llms/__init__.py b/libs/kotaemon/kotaemon/llms/__init__.py index 6494fc9..03d517d 100644 --- a/libs/kotaemon/kotaemon/llms/__init__.py +++ b/libs/kotaemon/kotaemon/llms/__init__.py @@ -10,6 +10,7 @@ from .chats import ( LCAnthropicChat, LCAzureChatOpenAI, LCChatOpenAI, + LCGeminiChat, LlamaCppChat, ) from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI @@ -29,6 +30,7 @@ __all__ = [ "AzureChatOpenAI", "ChatOpenAI", "LCAnthropicChat", + "LCGeminiChat", "LCAzureChatOpenAI", "LCChatOpenAI", "LlamaCppChat", diff --git a/libs/kotaemon/kotaemon/llms/chats/__init__.py b/libs/kotaemon/kotaemon/llms/chats/__init__.py index 6e3d3d5..7f82c9a 100644 --- a/libs/kotaemon/kotaemon/llms/chats/__init__.py +++ b/libs/kotaemon/kotaemon/llms/chats/__init__.py @@ -5,6 +5,7 @@ from .langchain_based import ( LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI, + LCGeminiChat, ) from .llamacpp import LlamaCppChat from .openai import AzureChatOpenAI, ChatOpenAI @@ -16,6 +17,7 @@ __all__ = [ "EndpointChatLLM", "ChatOpenAI", "LCAnthropicChat", + "LCGeminiChat", "LCChatOpenAI", "LCAzureChatOpenAI", "LCChatMixin", diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py index 077c3f8..663c195 100644 --- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py @@ -245,3 +245,27 @@ class LCAnthropicChat(LCChatMixin, ChatLLM): # type: ignore raise ImportError("Please install langchain-anthropic") return ChatAnthropic + + +class LCGeminiChat(LCChatMixin, ChatLLM): # type: ignore + def __init__( + self, + api_key: str | None = None, + model_name: str | None = None, + temperature: float = 0.7, + **params, + ): + super().__init__( + google_api_key=api_key, + model=model_name, + temperature=temperature, + **params, + ) + + def _get_lc_class(self): + try: + from langchain_google_genai import ChatGoogleGenerativeAI + except ImportError: + raise ImportError("Please install langchain-google-genai") + + return ChatGoogleGenerativeAI diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index e3299b7..afd3854 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -30,6 +30,8 @@ dependencies = [ "langchain>=0.1.16,<0.2.0", "langchain-community>=0.0.34,<0.1.0", "langchain-openai>=0.1.4,<0.2.0", + "langchain-anthropic", + "langchain-google-genai", "llama-hub>=0.0.79,<0.1.0", "llama-index>=0.10.40,<0.11.0", "llama-index-vector-stores-chroma>=0.1.9", diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py index f4b204f..3ac90a5 100644 --- a/libs/ktem/ktem/llms/manager.py +++ b/libs/ktem/ktem/llms/manager.py @@ -54,9 +54,21 @@ class LLMManager: self._default = item.name def load_vendors(self): - from kotaemon.llms import AzureChatOpenAI, ChatOpenAI, LlamaCppChat + from kotaemon.llms import ( + AzureChatOpenAI, + ChatOpenAI, + LCAnthropicChat, + LCGeminiChat, + LlamaCppChat, + ) - self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat] + self._vendors = [ + ChatOpenAI, + AzureChatOpenAI, + LCAnthropicChat, + LCGeminiChat, + LlamaCppChat, + ] for extra_vendor in getattr(flowsettings, "KH_LLM_EXTRA_VENDORS", []): self._vendors.append(import_dotted_string(extra_vendor, safe=False))