allow LlamaCppChat to auto download model from hf hub (#29)

Author: ian_Cin, 2024-04-13 18:57:04 +07:00, committed by GitHub
parent 917fb0a082
commit 4022af7e9b
3 changed files with 37 additions and 17 deletions


@@ -12,9 +12,14 @@ if TYPE_CHECKING:
 class LlamaCppChat(ChatLLM):
     """Wrapper around the llama-cpp-python's Llama model"""
 
-    model_path: str = Param(
+    model_path: Optional[str] = Param(
         help="Path to the model file. This is required to load the model.",
-        required=True,
+    )
+    repo_id: Optional[str] = Param(
+        help="Id of a repo on the HuggingFace Hub in the form of `user_name/repo_name`."
+    )
+    filename: Optional[str] = Param(
+        help="A filename or glob pattern to match the model file in the repo."
     )
     chat_format: str = Param(
         help=(
@@ -28,7 +33,7 @@ class LlamaCppChat(ChatLLM):
     n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
     n_gpu_layers: Optional[int] = Param(
         0,
-        help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
+        help="Number of layers to offload to GPU. If -1, all layers are offloaded",
     )
     use_mmap: Optional[bool] = Param(
         True,
@@ -36,7 +41,7 @@ class LlamaCppChat(ChatLLM):
     )
     vocab_only: Optional[bool] = Param(
         False,
-        help=("If True, only the vocabulary is loaded. This is useful for debugging."),
+        help="If True, only the vocabulary is loaded. This is useful for debugging.",
     )
 
     _role_mapper: dict[str, str] = {
@@ -57,8 +62,11 @@ class LlamaCppChat(ChatLLM):
         )
 
         errors = []
-        if not self.model_path:
-            errors.append("- `model_path` is required to load the model")
+        if not self.model_path and (not self.repo_id or not self.filename):
+            errors.append(
+                "- `model_path` or `repo_id` and `filename` are required to load the"
+                " model"
+            )
 
         if not self.chat_format:
             errors.append(
@@ -69,15 +77,27 @@ class LlamaCppChat(ChatLLM):
         if errors:
             raise ValueError("\n".join(errors))
 
-        return Llama(
-            model_path=cast(str, self.model_path),
-            chat_format=self.chat_format,
-            lora_base=self.lora_base,
-            n_ctx=self.n_ctx,
-            n_gpu_layers=self.n_gpu_layers,
-            use_mmap=self.use_mmap,
-            vocab_only=self.vocab_only,
-        )
+        if self.model_path:
+            return Llama(
+                model_path=cast(str, self.model_path),
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
+        else:
+            return Llama.from_pretrained(
+                repo_id=self.repo_id,
+                filename=self.filename,
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
 
     def prepare_message(
         self, messages: str | BaseMessage | list[BaseMessage]
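
For context, a minimal usage sketch of the two loading paths this change enables. The import path, repo id, filename pattern, and local path below are assumptions for illustration only, not values pinned by this commit; Llama.from_pretrained in llama-cpp-python fetches the matching GGUF file from the HuggingFace Hub (it requires the huggingface-hub package).

# Sketch only; assumed import path and illustrative model names.
from kotaemon.llms import LlamaCppChat

# Path 1: load a local GGUF file, as before.
local_llm = LlamaCppChat(
    model_path="models/my-model-q4_0.gguf",  # hypothetical local path
    chat_format="chatml",
)

# Path 2 (new): omit model_path and pass repo_id + filename so the model
# is auto-downloaded from the HuggingFace Hub via Llama.from_pretrained.
hub_llm = LlamaCppChat(
    repo_id="Qwen/Qwen1.5-1.8B-Chat-GGUF",  # illustrative repo id
    filename="*q4_0.gguf",                  # glob matched against files in the repo
    chat_format="chatml",
)

# kotaemon LLM components are callable (assumed), so either instance can be invoked:
response = hub_llm("Hello, who are you?")
print(response)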


@@ -22,7 +22,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0,<0.10.0",
     "llama-hub",
-    "gradio>=4.0.0,<=4.22.0",
+    "gradio>=4.26.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -34,6 +34,7 @@ dependencies = [
     "unstructured",
     "pypdf",
     "html2text",
+    "fastembed",
 ]
 readme = "README.md"
 license = { text = "MIT License" }


@@ -19,7 +19,6 @@ dependencies = [
     "python-decouple",
     "sqlalchemy",
     "sqlmodel",
-    "fastembed",
     "tiktoken",
     "gradio>=4.26.0",
     "markdown",