Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class Llama:
def __init__(
self,
model_path: str,
clip_model_path: Optional[str] = None,
*,
# Model Params
n_gpu_layers: Union[int, Literal["auto", "all"]] = "auto",
Expand Down Expand Up @@ -608,6 +609,17 @@ def __init__(

if self.verbose:
print(f"Model metadata: {self.metadata}", file=sys.stderr)

if clip_model_path is not None:
if self.chat_handler is not None and self.verbose:
print("Warning: Both `chat_handler` and `clip_model_path` are not null. Chat handler will be overwritten.", flush = True)

self.chat_handler = llama_chat_format.GenericMTMDChatHandler(
gguf_metadata = self.metadata,
clip_model_path = clip_model_path,
model_arch = None,
verbose = self.verbose
)

eos_token_id = self.token_eos()
bos_token_id = self.token_bos()
Expand Down
49 changes: 48 additions & 1 deletion llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2887,10 +2887,14 @@ def __init__(
raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {clip_model_path}")

# Pre-compile Jinja template
if not hasattr(self, "chat_format") or self.chat_format is None:
self.chat_format = self.CHAT_FORMAT

self._chat_format_parser_tags = []
self.chat_template = ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
).from_string(self.CHAT_FORMAT)
).from_string(self.chat_format)

self._exit_stack = ExitStack()

Expand Down Expand Up @@ -3116,6 +3120,13 @@ def _process_mtmd_prompt(
tool_choice=tool_choice,
**getattr(self, 'extra_template_arguments', {})
)

for tag in self._chat_format_parser_tags:
if tag not in text:
continue

text = text.replace(tag, media_marker)

# Replace image_url by media_marker in text
for item in media_items:
text = text.replace(item["url"], media_marker)
Expand Down Expand Up @@ -3827,6 +3838,42 @@ def from_pretrained(
**kwargs,
)

class GenericMTMDChatHandler(MTMDChatHandler):
    """MTMD chat handler that configures itself from a model's GGUF metadata.

    Instead of a hard-coded chat format, the Jinja chat template is read from
    the ``tokenizer.chat_template`` metadata key and the model architecture
    from ``general.architecture``. The architecture selects which media
    placeholder tags (image/audio/video) are later swapped for the multimodal
    media marker during prompt processing.
    """

    def __init__(
        self,
        gguf_metadata: Dict[str, Any],
        clip_model_path: str,
        model_arch: Optional[str] = None,
        verbose: bool = True,
        **kwargs: Any,
    ) -> None:
        """Build a handler from GGUF metadata.

        Args:
            gguf_metadata: Key/value metadata read from the loaded GGUF model.
            clip_model_path: Path to the multimodal projector (clip) model file.
            model_arch: Explicit architecture override; when ``None`` the
                ``general.architecture`` metadata value is used.
            verbose: Emit diagnostic output when True.
            **kwargs: Forwarded to the parent ``MTMDChatHandler`` constructor.

        Raises:
            ValueError: If the metadata contains no chat template.
        """
        self.model_metadata = gguf_metadata

        self.chat_format = self.model_metadata.get("tokenizer.chat_template", None)
        self.arch = (
            self.model_metadata.get("general.architecture", None)
            if model_arch is None
            else model_arch
        )

        # Validate before logging: the original order printed a misleading
        # "```jinja\nNone\n```" template just before raising.
        if self.chat_format is None:
            raise ValueError("Failed to get model chat template automatically.")

        if verbose:
            print(
                f"Got chat template from model:\n```jinja\n{self.chat_format}\n```",
                flush=True,
            )

        if self.arch is None:
            if verbose:
                print("Unknown model architecture. Will use general/most-common tags.")
            self.arch = "unknown"

        # Parent __init__ compiles self.chat_format and initializes
        # self._chat_format_parser_tags to an empty list (see MTMDChatHandler).
        super().__init__(clip_model_path=clip_model_path, verbose=verbose, **kwargs)

        # Architecture-specific media placeholder tags; during prompt
        # processing each occurrence is replaced with the media marker.
        if self.arch in ("unknown", "qwen3vl", "qwen35moe", "qwen35"):
            self._chat_format_parser_tags += ["<|image_pad|>", "<|audio_pad|>", "<|video_pad|>"]
        elif self.arch == "gemma4":
            self._chat_format_parser_tags += ["<|image|>", "<|audio|>", "<|video|>"]
        elif self.arch in ("mistral3", "mistral4", "deepseek2"):
            self._chat_format_parser_tags += ["[IMG]"]
        elif verbose:
            print("Warning: Could not determine chat format parser tags.", flush=True)

class Llava15ChatHandler(MTMDChatHandler):
CHAT_FORMAT = (
Expand Down
Loading