diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 1241f81e2..848706a90 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -85,6 +85,7 @@ class Llama: def __init__( self, model_path: str, + clip_model_path: Optional[str] = None, *, # Model Params n_gpu_layers: Union[int, Literal["auto", "all"]] = "auto", @@ -608,6 +609,17 @@ def __init__( if self.verbose: print(f"Model metadata: {self.metadata}", file=sys.stderr) + + if clip_model_path is not None: + if self.chat_handler is not None and self.verbose: + print("Warning: Both `chat_handler` and `clip_model_path` are not null. Chat handler will be overwritten.", flush = True) + + self.chat_handler = llama_chat_format.GenericMTMDChatHandler( + gguf_metadata = self.metadata, + clip_model_path = clip_model_path, + model_arch = None, + verbose = self.verbose + ) eos_token_id = self.token_eos() bos_token_id = self.token_bos() diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index a0d8d25db..ab5e438d3 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2887,10 +2887,14 @@ def __init__( raise ValueError(f"{self.log_prefix}(__init__): Clip model path does not exist: {clip_model_path}") # Pre-compile Jinja template + if not hasattr(self, "chat_format") or self.chat_format is None: + self.chat_format = self.CHAT_FORMAT + + self._chat_format_parser_tags = [] self.chat_template = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, - ).from_string(self.CHAT_FORMAT) + ).from_string(self.chat_format) self._exit_stack = ExitStack() @@ -3116,6 +3120,13 @@ def _process_mtmd_prompt( tool_choice=tool_choice, **getattr(self, 'extra_template_arguments', {}) ) + + for tag in self._chat_format_parser_tags: + if tag not in text: + continue + + text = text.replace(tag, media_marker) + # Replace image_url by media_marker in text for item in media_items: text = text.replace(item["url"], media_marker) @@ -3827,6 +3838,42 @@ def from_pretrained( **kwargs, ) +class GenericMTMDChatHandler(MTMDChatHandler): + def __init__( + self, + gguf_metadata: Dict[str, Any], + clip_model_path: str, + model_arch: Optional[str] = None, + verbose: bool = True, + **kwargs + ) -> None: + self.model_metadata = gguf_metadata + + self.chat_format = self.model_metadata.get("tokenizer.chat_template", None) + self.arch = self.model_metadata.get("general.architecture", None) if model_arch is None else model_arch + + if verbose: + print(f"Got chat template from model:\n```jinja\n{self.chat_format}\n```", flush = True) + + if self.arch is None: + if verbose: + print("Unknown model architecture. Will use general/most-common tags.") + + self.arch = "unknown" + + if self.chat_format is None: + raise ValueError("Failed to get model chat template automatically.") + + super().__init__(clip_model_path = clip_model_path, verbose = verbose, **kwargs) + + if self.arch in ["unknown", "qwen3vl", "qwen35moe", "qwen35"]: + self._chat_format_parser_tags += ["<|image_pad|>", "<|audio_pad|>", "<|video_pad|>"] + elif self.arch in ["gemma4"]: + self._chat_format_parser_tags += ["<|image|>", "<|audio|>", "<|video|>"] + elif self.arch in ["mistral3", "mistral4", "deepseek2"]: + self._chat_format_parser_tags += ["[IMG]"] + elif verbose: + print("Warning: Could not determine chat format parser tags.", flush = True) class Llava15ChatHandler(MTMDChatHandler): CHAT_FORMAT = (