allow LlamaCppChat to auto download model from hf hub (#29)
This commit is contained in:
parent
917fb0a082
commit
4022af7e9b
@@ -12,9 +12,14 @@ if TYPE_CHECKING:
 class LlamaCppChat(ChatLLM):
     """Wrapper around the llama-cpp-python's Llama model"""
 
-    model_path: str = Param(
+    model_path: Optional[str] = Param(
         help="Path to the model file. This is required to load the model.",
-        required=True,
     )
+    repo_id: Optional[str] = Param(
+        help="Id of a repo on the HuggingFace Hub in the form of `user_name/repo_name`."
+    )
+    filename: Optional[str] = Param(
+        help="A filename or glob pattern to match the model file in the repo."
+    )
     chat_format: str = Param(
         help=(
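For context on what the new `repo_id` and `filename` fields enable: llama-cpp-python can fetch a GGUF file straight from the HuggingFace Hub via `Llama.from_pretrained` (used in the last hunk of this file), with extra keyword arguments forwarded to the regular `Llama` constructor. A minimal standalone sketch, assuming `huggingface-hub` is installed; the repo id and filename pattern below are hypothetical examples, not taken from this change:

# Minimal sketch: download-and-load via llama-cpp-python.
# The repo id and filename below are hypothetical examples.
from llama_cpp import Llama

llm = Llama.from_pretrained(
    repo_id="TheBloke/Llama-2-7B-Chat-GGUF",  # Hub repo in `user_name/repo_name` form
    filename="*Q4_K_M.gguf",                  # glob pattern matched against files in the repo
    n_ctx=512,                                # remaining kwargs go to Llama.__init__
)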
@@ -28,7 +33,7 @@ class LlamaCppChat(ChatLLM):
     n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
     n_gpu_layers: Optional[int] = Param(
         0,
-        help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
+        help="Number of layers to offload to GPU. If -1, all layers are offloaded",
     )
     use_mmap: Optional[bool] = Param(
         True,
@@ -36,7 +41,7 @@ class LlamaCppChat(ChatLLM):
     )
     vocab_only: Optional[bool] = Param(
         False,
-        help=("If True, only the vocabulary is loaded. This is useful for debugging."),
+        help="If True, only the vocabulary is loaded. This is useful for debugging.",
     )
 
     _role_mapper: dict[str, str] = {
@@ -57,8 +62,11 @@ class LlamaCppChat(ChatLLM):
         )
 
         errors = []
-        if not self.model_path:
-            errors.append("- `model_path` is required to load the model")
+        if not self.model_path and (not self.repo_id or not self.filename):
+            errors.append(
+                "- `model_path` or `repo_id` and `filename` are required to load the"
+                " model"
+            )
 
         if not self.chat_format:
             errors.append(
@@ -69,15 +77,27 @@ class LlamaCppChat(ChatLLM):
         if errors:
             raise ValueError("\n".join(errors))
 
-        return Llama(
-            model_path=cast(str, self.model_path),
-            chat_format=self.chat_format,
-            lora_base=self.lora_base,
-            n_ctx=self.n_ctx,
-            n_gpu_layers=self.n_gpu_layers,
-            use_mmap=self.use_mmap,
-            vocab_only=self.vocab_only,
-        )
+        if self.model_path:
+            return Llama(
+                model_path=cast(str, self.model_path),
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
+        else:
+            return Llama.from_pretrained(
+                repo_id=self.repo_id,
+                filename=self.filename,
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
 
     def prepare_message(
         self, messages: str | BaseMessage | list[BaseMessage]
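With this change the wrapper can be configured either with a local GGUF file or with a Hub repo, and the loader picks `Llama` or `Llama.from_pretrained` accordingly. A usage sketch under two assumptions: that Params can be passed as constructor keyword arguments (as elsewhere in kotaemon) and that the import path below matches the package layout; the file path, repo id, and filename are hypothetical:

from kotaemon.llms import LlamaCppChat  # import path assumed, adjust to the package layout

# Option 1: local GGUF file on disk (path is hypothetical).
local_llm = LlamaCppChat(
    model_path="/models/llama-2-7b-chat.Q4_K_M.gguf",
    chat_format="llama-2",
)

# Option 2: new in this commit, llama-cpp-python downloads the weights
# from the HuggingFace Hub (repo id and filename are hypothetical).
hub_llm = LlamaCppChat(
    repo_id="TheBloke/Llama-2-7B-Chat-GGUF",
    filename="*Q4_K_M.gguf",
    chat_format="llama-2",
)

Per the `prepare_message` signature above, either instance can then be invoked with a plain string or a list of `BaseMessage` objects.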
@@ -22,7 +22,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0,<0.10.0",
     "llama-hub",
-    "gradio>=4.0.0,<=4.22.0",
+    "gradio>=4.26.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -34,6 +34,7 @@ dependencies = [
     "unstructured",
     "pypdf",
     "html2text",
+    "fastembed",
 ]
 readme = "README.md"
 license = { text = "MIT License" }
@@ -19,7 +19,6 @@ dependencies = [
     "python-decouple",
     "sqlalchemy",
     "sqlmodel",
-    "fastembed",
     "tiktoken",
     "gradio>=4.26.0",
     "markdown",