allow LlamaCppChat to auto download model from hf hub (#29)

commit 4022af7e9b
parent 917fb0a082
@@ -12,9 +12,14 @@ if TYPE_CHECKING:
 class LlamaCppChat(ChatLLM):
     """Wrapper around the llama-cpp-python's Llama model"""
 
-    model_path: str = Param(
+    model_path: Optional[str] = Param(
         help="Path to the model file. This is required to load the model.",
-        required=True,
     )
+    repo_id: Optional[str] = Param(
+        help="Id of a repo on the HuggingFace Hub in the form of `user_name/repo_name`."
+    )
+    filename: Optional[str] = Param(
+        help="A filename or glob pattern to match the model file in the repo."
+    )
     chat_format: str = Param(
         help=(
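The hunk above makes `model_path` optional and adds `repo_id` and `filename` so the model can instead be pulled from the Hugging Face Hub. A minimal usage sketch, not part of the commit; the import path, model file, repo id, and chat format below are placeholder assumptions:

    # Hypothetical usage of the fields introduced above.
    from kotaemon.llms import LlamaCppChat  # assumed import path

    # Option 1: load a GGUF file that is already on disk.
    llm_local = LlamaCppChat(
        model_path="models/mistral-7b-instruct.Q4_K_M.gguf",  # placeholder path
        chat_format="mistral-instruct",
    )

    # Option 2: give a Hub repo and a filename/glob; llama-cpp-python downloads it.
    llm_hub = LlamaCppChat(
        repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF",  # placeholder repo
        filename="*Q4_K_M.gguf",
        chat_format="mistral-instruct",
    )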
@@ -28,7 +33,7 @@ class LlamaCppChat(ChatLLM):
     n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
     n_gpu_layers: Optional[int] = Param(
         0,
-        help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
+        help="Number of layers to offload to GPU. If -1, all layers are offloaded",
     )
     use_mmap: Optional[bool] = Param(
         True,
@@ -36,7 +41,7 @@ class LlamaCppChat(ChatLLM):
     )
     vocab_only: Optional[bool] = Param(
         False,
-        help=("If True, only the vocabulary is loaded. This is useful for debugging."),
+        help="If True, only the vocabulary is loaded. This is useful for debugging.",
     )
 
     _role_mapper: dict[str, str] = {
@@ -57,8 +62,11 @@ class LlamaCppChat(ChatLLM):
         )
 
         errors = []
-        if not self.model_path:
-            errors.append("- `model_path` is required to load the model")
+        if not self.model_path and (not self.repo_id or not self.filename):
+            errors.append(
+                "- `model_path` or `repo_id` and `filename` are required to load the"
+                " model"
+            )
 
         if not self.chat_format:
             errors.append(
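The validation above accepts either a local path or a complete repo/filename pair. A small illustration of that condition (not code from the commit; the helper name is hypothetical):

    # The config is valid if model_path is set, or if both repo_id and filename are set,
    # i.e. the negation of `not model_path and (not repo_id or not filename)`.
    def has_model_source(model_path, repo_id, filename) -> bool:
        return bool(model_path) or (bool(repo_id) and bool(filename))

    assert has_model_source("models/llama.gguf", None, None)   # local file
    assert has_model_source(None, "user/repo", "*.gguf")       # hub download
    assert not has_model_source(None, "user/repo", None)       # filename missing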
@@ -69,15 +77,27 @@ class LlamaCppChat(ChatLLM):
         if errors:
             raise ValueError("\n".join(errors))
 
-        return Llama(
-            model_path=cast(str, self.model_path),
-            chat_format=self.chat_format,
-            lora_base=self.lora_base,
-            n_ctx=self.n_ctx,
-            n_gpu_layers=self.n_gpu_layers,
-            use_mmap=self.use_mmap,
-            vocab_only=self.vocab_only,
-        )
+        if self.model_path:
+            return Llama(
+                model_path=cast(str, self.model_path),
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
+        else:
+            return Llama.from_pretrained(
+                repo_id=self.repo_id,
+                filename=self.filename,
+                chat_format=self.chat_format,
+                lora_base=self.lora_base,
+                n_ctx=self.n_ctx,
+                n_gpu_layers=self.n_gpu_layers,
+                use_mmap=self.use_mmap,
+                vocab_only=self.vocab_only,
+            )
 
     def prepare_message(
         self, messages: str | BaseMessage | list[BaseMessage]
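The new branch defers the download to llama-cpp-python's `Llama.from_pretrained`, which fetches the matching GGUF file from the Hugging Face Hub (via `huggingface_hub`) and then loads it like a local file. A standalone sketch of that call; the repo id and filename are placeholders, not values from the commit:

    from llama_cpp import Llama

    # Download the first file matching the glob from the repo, then load it.
    llm = Llama.from_pretrained(
        repo_id="Qwen/Qwen2-0.5B-Instruct-GGUF",  # placeholder Hub repo
        filename="*q8_0.gguf",                    # glob matched against files in the repo
        chat_format="chatml",
        n_ctx=512,
    )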
@@ -22,7 +22,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0,<0.10.0",
     "llama-hub",
-    "gradio>=4.0.0,<=4.22.0",
+    "gradio>=4.26.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -34,6 +34,7 @@ dependencies = [
     "unstructured",
     "pypdf",
     "html2text",
+    "fastembed",
 ]
 readme = "README.md"
 license = { text = "MIT License" }
@@ -19,7 +19,6 @@ dependencies = [
     "python-decouple",
     "sqlalchemy",
     "sqlmodel",
-    "fastembed",
     "tiktoken",
     "gradio>=4.26.0",
     "markdown",