allow LlamaCppChat to auto download model from hf hub (#29)

Authored by ian_Cin on 2024-04-13 18:57:04 +07:00, committed by GitHub
parent 917fb0a082
commit 4022af7e9b
3 changed files with 37 additions and 17 deletions

View File

@@ -12,9 +12,14 @@ if TYPE_CHECKING:
 class LlamaCppChat(ChatLLM):
     """Wrapper around the llama-cpp-python's Llama model"""
 
-    model_path: str = Param(
+    model_path: Optional[str] = Param(
         help="Path to the model file. This is required to load the model.",
-        required=True,
     )
+    repo_id: Optional[str] = Param(
+        help="Id of a repo on the HuggingFace Hub in the form of `user_name/repo_name`."
+    )
+    filename: Optional[str] = Param(
+        help="A filename or glob pattern to match the model file in the repo."
+    )
     chat_format: str = Param(
         help=(
@@ -28,7 +33,7 @@ class LlamaCppChat(ChatLLM):
     n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
     n_gpu_layers: Optional[int] = Param(
         0,
-        help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
+        help="Number of layers to offload to GPU. If -1, all layers are offloaded",
     )
     use_mmap: Optional[bool] = Param(
         True,
@@ -36,7 +41,7 @@ class LlamaCppChat(ChatLLM):
     )
     vocab_only: Optional[bool] = Param(
         False,
-        help=("If True, only the vocabulary is loaded. This is useful for debugging."),
+        help="If True, only the vocabulary is loaded. This is useful for debugging.",
     )
 
     _role_mapper: dict[str, str] = {
@@ -57,8 +62,11 @@ class LlamaCppChat(ChatLLM):
             )
 
         errors = []
-        if not self.model_path:
-            errors.append("- `model_path` is required to load the model")
+        if not self.model_path and (not self.repo_id or not self.filename):
+            errors.append(
+                "- `model_path` or `repo_id` and `filename` are required to load the"
+                " model"
+            )
 
         if not self.chat_format:
             errors.append(
@@ -69,6 +77,7 @@ class LlamaCppChat(ChatLLM):
         if errors:
             raise ValueError("\n".join(errors))
 
+        if self.model_path:
             return Llama(
                 model_path=cast(str, self.model_path),
                 chat_format=self.chat_format,
@@ -78,6 +87,17 @@ class LlamaCppChat(ChatLLM):
                 use_mmap=self.use_mmap,
                 vocab_only=self.vocab_only,
             )
+        else:
+            return Llama.from_pretrained(
+                repo_id=self.repo_id,
+                filename=self.filename,
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
 
     def prepare_message(
         self, messages: str | BaseMessage | list[BaseMessage]
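
For context, a minimal usage sketch of the two loading modes the change above enables. The import path, repo id, chat format, and file names below are illustrative assumptions, not taken from the diff:

from kotaemon.llms import LlamaCppChat  # assumed import path

# Previous behaviour: load a GGUF file that already exists on disk.
local_llm = LlamaCppChat(
    model_path="models/mistral-7b-instruct-v0.2.Q4_K_M.gguf",  # hypothetical local path
    chat_format="mistral-instruct",
)

# New behaviour: give a HuggingFace Hub repo id plus a filename or glob
# pattern and let the model file be downloaded automatically on first use.
hub_llm = LlamaCppChat(
    repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",  # example Hub repo
    filename="*Q4_K_M.gguf",  # glob matched against GGUF files in the repo
    chat_format="mistral-instruct",
)

When `model_path` is unset, the new branch delegates to llama-cpp-python's `Llama.from_pretrained`, which resolves `filename` against the repo and fetches the matching file through `huggingface_hub` before loading it.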

View File

@@ -22,7 +22,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0,<0.10.0",
     "llama-hub",
-    "gradio>=4.0.0,<=4.22.0",
+    "gradio>=4.26.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -34,6 +34,7 @@ dependencies = [
     "unstructured",
     "pypdf",
     "html2text",
+    "fastembed",
 ]
 readme = "README.md"
 license = { text = "MIT License" }

View File

@@ -19,7 +19,6 @@ dependencies = [
     "python-decouple",
     "sqlalchemy",
     "sqlmodel",
-    "fastembed",
     "tiktoken",
     "gradio>=4.26.0",
     "markdown",