From 4022af7e9b90775f3e9da2bc92b2352c3b485591 Mon Sep 17 00:00:00 2001
From: ian_Cin
Date: Sat, 13 Apr 2024 18:57:04 +0700
Subject: [PATCH] allow LlamaCppChat to auto download model from hf hub (#29)

---
 libs/kotaemon/kotaemon/llms/chats/llamacpp.py | 50 +++++++++++++------
 libs/kotaemon/pyproject.toml                  |  3 +-
 libs/ktem/pyproject.toml                      |  1 -
 3 files changed, 37 insertions(+), 17 deletions(-)

diff --git a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py
index 7b8bee4..11f9abb 100644
--- a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py
+++ b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py
@@ -12,9 +12,14 @@ if TYPE_CHECKING:
 class LlamaCppChat(ChatLLM):
     """Wrapper around the llama-cpp-python's Llama model"""
 
-    model_path: str = Param(
+    model_path: Optional[str] = Param(
         help="Path to the model file. This is required to load the model.",
-        required=True,
+    )
+    repo_id: Optional[str] = Param(
+        help="Id of a repo on the HuggingFace Hub in the form of `user_name/repo_name`."
+    )
+    filename: Optional[str] = Param(
+        help="A filename or glob pattern to match the model file in the repo."
     )
     chat_format: str = Param(
         help=(
@@ -28,7 +33,7 @@ class LlamaCppChat(ChatLLM):
     n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
     n_gpu_layers: Optional[int] = Param(
         0,
-        help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
+        help="Number of layers to offload to GPU. If -1, all layers are offloaded",
     )
     use_mmap: Optional[bool] = Param(
         True,
@@ -36,7 +41,7 @@ class LlamaCppChat(ChatLLM):
     )
     vocab_only: Optional[bool] = Param(
         False,
-        help=("If True, only the vocabulary is loaded. This is useful for debugging."),
+        help="If True, only the vocabulary is loaded. This is useful for debugging.",
     )
 
     _role_mapper: dict[str, str] = {
@@ -57,8 +62,11 @@ class LlamaCppChat(ChatLLM):
         )
 
         errors = []
-        if not self.model_path:
-            errors.append("- `model_path` is required to load the model")
+        if not self.model_path and (not self.repo_id or not self.filename):
+            errors.append(
+                "- `model_path` or `repo_id` and `filename` are required to load the"
+                " model"
+            )
 
         if not self.chat_format:
             errors.append(
@@ -69,15 +77,27 @@ class LlamaCppChat(ChatLLM):
         if errors:
             raise ValueError("\n".join(errors))
 
-        return Llama(
-            model_path=cast(str, self.model_path),
-            chat_format=self.chat_format,
-            lora_base=self.lora_base,
-            n_ctx=self.n_ctx,
-            n_gpu_layers=self.n_gpu_layers,
-            use_mmap=self.use_mmap,
-            vocab_only=self.vocab_only,
-        )
+        if self.model_path:
+            return Llama(
+                model_path=cast(str, self.model_path),
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
+        else:
+            return Llama.from_pretrained(
+                repo_id=self.repo_id,
+                filename=self.filename,
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
 
     def prepare_message(
         self, messages: str | BaseMessage | list[BaseMessage]
diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml
index 31b538f..2fd5ad7 100644
--- a/libs/kotaemon/pyproject.toml
+++ b/libs/kotaemon/pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0,<0.10.0",
     "llama-hub",
-    "gradio>=4.0.0,<=4.22.0",
+    "gradio>=4.26.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -34,6 +34,7 @@ dependencies = [
     "unstructured",
     "pypdf",
     "html2text",
+    "fastembed",
 ]
 readme = "README.md"
 license = { text = "MIT License" }
diff --git a/libs/ktem/pyproject.toml b/libs/ktem/pyproject.toml
index 92bb36e..1e66b3f 100644
--- a/libs/ktem/pyproject.toml
+++ b/libs/ktem/pyproject.toml
@@ -19,7 +19,6 @@ dependencies = [
     "python-decouple",
     "sqlalchemy",
     "sqlmodel",
-    "fastembed",
     "tiktoken",
     "gradio>=4.26.0",
     "markdown",