diff --git a/src/maxdiffusion/configs/ltx2_video.yml b/src/maxdiffusion/configs/ltx2_video.yml index 9b676d4cc..1d57b3467 100644 --- a/src/maxdiffusion/configs/ltx2_video.yml +++ b/src/maxdiffusion/configs/ltx2_video.yml @@ -89,6 +89,8 @@ enable_profiler: False enable_ml_diagnostics: True profiler_gcs_path: "gs://mehdy/profiler/ml_diagnostics" enable_ondemand_xprof: True +skip_first_n_steps_for_profiler: 0 +profiler_steps: 5 replicate_vae: False diff --git a/src/maxdiffusion/generate_ltx2.py b/src/maxdiffusion/generate_ltx2.py index d4a356d9d..93f604352 100644 --- a/src/maxdiffusion/generate_ltx2.py +++ b/src/maxdiffusion/generate_ltx2.py @@ -116,7 +116,9 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log("Could not retrieve Git commit hash.") checkpoint_loader = LTX2Checkpointer(config=config) + load_time = 0.0 if pipeline is None: + t0_load = time.perf_counter() # Use the config flag to determine if the upsampler should be loaded run_latent_upsampler = getattr(config, "run_latent_upsampler", False) pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler) @@ -145,6 +147,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): scan_layers=config.scan_layers, dtype=config.weights_dtype, ) + load_time = time.perf_counter() - t0_load pipeline.enable_vae_slicing() pipeline.enable_vae_tiling() @@ -162,12 +165,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) - out = call_pipeline(config, pipeline, prompt, negative_prompt) - - # out should have .frames and .audio - videos = out.frames if hasattr(out, "frames") else out[0] - audios = out.audio if hasattr(out, "audio") else None - max_logging.log("===================== Model details =======================") max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}") 
max_logging.log(f"model path: {config.pretrained_model_name_or_path}") @@ -179,11 +176,48 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}") max_logging.log("============================================================") + original_enable_profiler = config.get_keys().get("enable_profiler", False) + original_enable_mld = config.get_keys().get("enable_ml_diagnostics", False) + original_num_steps = config.get_keys().get("num_inference_steps", 40) + + # --------------------------------------------------------- + # Run 1: Warmup Compilation (Original steps, NO profiling) + # --------------------------------------------------------- + config.get_keys()["enable_profiler"] = False + config.get_keys()["enable_ml_diagnostics"] = False + + max_logging.log(f"šŸš€ Starting warmup compilation pass ({original_num_steps} steps)...") + _ = call_pipeline(config, pipeline, prompt, negative_prompt) + compile_time = time.perf_counter() - s0 max_logging.log(f"compile_time: {compile_time}") if writer and jax.process_index() == 0: writer.add_scalar("inference/compile_time", compile_time, global_step=0) + # --------------------------------------------------------- + # Run 2: Actual Generation (Original steps, NO profiling) + # --------------------------------------------------------- + + s0 = time.perf_counter() + max_logging.log("šŸš€ Starting actual full-length generation pass...") + out = call_pipeline(config, pipeline, prompt, negative_prompt) + generation_time = time.perf_counter() - s0 + max_logging.log(f"generation_time: {generation_time}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time", generation_time, global_step=0) + num_devices = jax.device_count() + num_videos = num_devices * config.per_device_batch_size + if num_videos > 0: + generation_time_per_video = generation_time / num_videos + writer.add_scalar("inference/generation_time_per_video", 
generation_time_per_video, global_step=0) + max_logging.log(f"generation time per video: {generation_time_per_video}") + else: + max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") + + # out should have .frames and .audio + videos = out.frames if hasattr(out, "frames") else out[0] + audios = out.audio if hasattr(out, "audio") else None + saved_video_path = [] audio_sample_rate = ( getattr(pipeline.vocoder.config, "output_sampling_rate", 24000) if hasattr(pipeline, "vocoder") else 24000 @@ -210,29 +244,68 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): if config.output_dir.startswith("gs://"): upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) - s0 = time.perf_counter() - call_pipeline(config, pipeline, prompt, negative_prompt) - generation_time = time.perf_counter() - s0 - max_logging.log(f"generation_time: {generation_time}") - if writer and jax.process_index() == 0: - writer.add_scalar("inference/generation_time", generation_time, global_step=0) - num_devices = jax.device_count() - num_videos = num_devices * config.per_device_batch_size - if num_videos > 0: - generation_time_per_video = generation_time / num_videos - writer.add_scalar("inference/generation_time_per_video", generation_time_per_video, global_step=0) - max_logging.log(f"generation time per video: {generation_time_per_video}") - else: - max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") + timing_str = ( + f"\n{'=' * 50}\n" + f" TIMING SUMMARY\n" + f"{'=' * 50}\n" + f" Load (checkpoint): {load_time:>7.1f}s\n" + f" Compile: {compile_time:>7.1f}s\n" + f" {'─' * 40}\n" + f" Inference: {generation_time:>7.1f}s\n" + ) + if hasattr(out, "timings") and out.timings: + timing_str += ( + f" Text Encoding: {out.timings.get('Text Encoding', 0.0):>7.1f}s\n" + f" Preparation: {out.timings.get('Preparation', 0.0):>7.1f}s\n" + f" Connectors: 
{out.timings.get('Connectors', 0.0):>7.1f}s\n" + f" Denoising: {out.timings.get('Denoising', 0.0):>7.1f}s\n" + ) + if out.timings.get("Latent Upsampler", 0.0) > 0.0: + timing_str += f" Latent Upsampler: {out.timings.get('Latent Upsampler', 0.0):>7.1f}s\n" + timing_str += ( + f" Latent Processing: {out.timings.get('Latent Processing', 0.0):>7.1f}s\n" + f" Video VAE: {out.timings.get('Video VAE', 0.0):>7.1f}s\n" + f" Video Post: {out.timings.get('Video Post', 0.0):>7.1f}s\n" + f" Audio VAE: {out.timings.get('Audio VAE', 0.0):>7.1f}s\n" + f" Vocoder: {out.timings.get('Vocoder', 0.0):>7.1f}s\n" + ) + timing_str += f"{'=' * 50}" + max_logging.log(timing_str) - s0 = time.perf_counter() - if max_utils.profiler_enabled(config): - with max_utils.Profiler(config): - call_pipeline(config, pipeline, prompt, negative_prompt) - generation_time_with_profiler = time.perf_counter() - s0 - max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") - if writer and jax.process_index() == 0: - writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) + # Free memory before profiling + del out + del videos + del audios + + # --------------------------------------------------------- + # Run 3: Profiling Run (Only if profiling was originally enabled) + # --------------------------------------------------------- + if original_enable_profiler or original_enable_mld: + skip_first_n_steps_for_profiler = config.get_keys().get("skip_first_n_steps_for_profiler", 0) + if skip_first_n_steps_for_profiler != 0: + max_logging.log( + "\nāš ļø WARNING: 'skip_first_n_steps_for_profiler' is ignored when 'scan_diffusion_loop' is enabled (the default here)! 
The profiler will capture all steps in this profile run.\n" + ) + + profiling_steps = config.get_keys().get("profiler_steps", 5) + + config.get_keys()["enable_profiler"] = False + config.get_keys()["enable_ml_diagnostics"] = False + config.get_keys()["num_inference_steps"] = profiling_steps + + max_logging.log(f"šŸš€ Warmup for profiling pass ({profiling_steps} steps)...") + _ = call_pipeline(config, pipeline, prompt, negative_prompt) + + config.get_keys()["enable_profiler"] = original_enable_profiler + config.get_keys()["enable_ml_diagnostics"] = original_enable_mld + + max_logging.log(f"šŸš€ Starting Profiling run ({profiling_steps} steps)...") + profiler = max_utils.Profiler(config, session_name=f"denoise_profile_{profiling_steps}_steps") + profiler.start() + + _ = call_pipeline(config, pipeline, prompt, negative_prompt) + + profiler.stop() + # Restore the original step count so the mutated config does not leak to callers reusing it. + config.get_keys()["num_inference_steps"] = original_num_steps return saved_video_path diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 9cc1c970e..81b3cdac2 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -17,6 +17,7 @@ from typing import Optional, Any, List, Union from functools import partial +import time import numpy as np import torch import jax @@ -60,6 +61,7 @@ class LTX2PipelineOutput: frames: jax.Array audio: Optional[jax.Array] = None + timings: Optional[Any] = None def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): @@ -1204,12 +1206,14 @@ def __call__( output_type: str = "pil", return_dict: bool = True, ): + t0_init = time.perf_counter() # 1. Check inputs self.check_inputs( prompt, height, width, prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask ) # 2.
Encode inputs (Text) + t0_encode = time.perf_counter() prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( prompt, negative_prompt, @@ -1222,6 +1226,10 @@ def __call__( max_sequence_length=max_sequence_length, dtype=dtype, ) + encode_time = time.perf_counter() - t0_encode + max_logging.log(f"Text Encoding time (Gemma-3 on CPU): {encode_time:.2f}s") + + t0_setup = time.perf_counter() # 3. Prepare latents batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0] @@ -1339,9 +1347,17 @@ def __call__( with context_manager, axis_rules_context: connectors_graphdef, connectors_state = nnx.split(self.connectors) - video_embeds, audio_embeds, new_attention_mask = self._run_connectors( - connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_) - ) + setup_time = time.perf_counter() - t0_setup + max_logging.log(f"Preparation/Setup time: {setup_time:.2f}s") + + t0_connectors = time.perf_counter() + with jax.named_scope("connectors_pass"): + video_embeds, audio_embeds, new_attention_mask = self._run_connectors( + connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_) + ) + video_embeds = video_embeds.block_until_ready() + connectors_time = time.perf_counter() - t0_connectors + max_logging.log(f"Connectors pass time: {connectors_time:.2f}s") video_embeds_sharded = video_embeds audio_embeds_sharded = audio_embeds @@ -1354,6 +1370,7 @@ def __call__( timesteps_jax = jnp.array(timesteps, dtype=jnp.float32) + t0_denoise = time.perf_counter() scan_diffusion_loop = getattr(self.config, "scan_diffusion_loop", True) if scan_diffusion_loop: @@ -1438,6 +1455,12 @@ def __call__( latents_jax = latents_step audio_latents_jax = audio_latents_step + jax.block_until_ready(latents_jax) + denoise_time = time.perf_counter() - t0_denoise + max_logging.log(f"Denoising steps time: {denoise_time:.2f}s") + + 
t0_latent_processing = time.perf_counter() + # 8. Decode Latents if guidance_scale > 1.0: latents_jax = latents_jax[batch_size:] @@ -1459,31 +1482,42 @@ def __call__( # VAE expects channels last (B, T, H, W, C) but unpack returns (B, C, T, H, W) latents = latents.transpose(0, 2, 3, 4, 1) + latent_processing_time = time.perf_counter() - t0_latent_processing + # ======================================================================= # LATENT UPSAMPLER # ======================================================================= if getattr(self.config, "run_latent_upsampler", False) and self.latent_upsampler is not None: max_logging.log("šŸš€ Running Latent Upsampler pass...") - if self.latent_upsampler_params is not None: - nnx.update(self.latent_upsampler, self.latent_upsampler_params) - self.latent_upsampler_params = None + upsampler_t0 = time.perf_counter() + with jax.named_scope("upsampler_pass"): + if self.latent_upsampler_params is not None: + nnx.update(self.latent_upsampler, self.latent_upsampler_params) + self.latent_upsampler_params = None - graphdef, state = nnx.split(self.latent_upsampler) + graphdef, state = nnx.split(self.latent_upsampler) - latents_upsampled = self._run_upsampler(graphdef, state, latents) + latents_upsampled = self._run_upsampler(graphdef, state, latents) - adain_factor = getattr(self.config, "upsampler_adain_factor", 0.0) - if adain_factor > 0.0: - latents = adain_filter_latent(latents_upsampled, latents, adain_factor) - else: - latents = latents_upsampled + adain_factor = getattr(self.config, "upsampler_adain_factor", 0.0) + if adain_factor > 0.0: + latents = adain_filter_latent(latents_upsampled, latents, adain_factor) + else: + latents = latents_upsampled + + tone_map_compression = getattr(self.config, "upsampler_tone_map_compression_ratio", 0.0) + if tone_map_compression > 0.0: + latents = tone_map_latents(latents, tone_map_compression) - tone_map_compression = getattr(self.config, "upsampler_tone_map_compression_ratio", 0.0) - if 
tone_map_compression > 0.0: - latents = tone_map_latents(latents, tone_map_compression) + jax.block_until_ready(latents) + upsampler_time = time.perf_counter() - upsampler_t0 + + max_logging.log(f"Latent Upsampler time: {upsampler_time:.2f}s") # ======================================================================= + t0_latent_processing = time.perf_counter() + # Denormalize and Unpack Audio (Order important: Denorm THEN Unpack) audio_latents = self._denormalize_audio_latents( audio_latents_jax, self.audio_vae.latents_mean.value, self.audio_vae.latents_std.value @@ -1500,8 +1534,17 @@ def __call__( if audio_latents.ndim == 4: audio_latents = audio_latents.transpose(0, 2, 3, 1) + timings = { + "Text Encoding": encode_time, + "Preparation": setup_time, + "Connectors": connectors_time, + "Denoising": denoise_time, + "Latent Processing": latent_processing_time, + "Latent Upsampler": upsampler_time if "upsampler_time" in locals() else 0.0, + } + if output_type == "latent": - return LTX2PipelineOutput(frames=latents, audio=audio_latents) + return LTX2PipelineOutput(frames=latents, audio=audio_latents, timings=timings) # Force latents and VAE weights to be fully replicated using with_sharding_constraint, this speeds up single video latency ~3x try: @@ -1518,43 +1561,70 @@ def __call__( except Exception as e: max_logging.log(f"[Tuning] Failed to apply sharding constraint: {e}") - if getattr(self.vae.config, "timestep_conditioning", False): - noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype) + latent_processing_time += time.perf_counter() - t0_latent_processing + timings["Latent Processing"] = latent_processing_time - if not isinstance(decode_timestep, list): - decode_timestep = [decode_timestep] * batch_size - if decode_noise_scale is None: - decode_noise_scale = decode_timestep - elif not isinstance(decode_noise_scale, list): - decode_noise_scale = [decode_noise_scale] * batch_size + t0_video_vae = time.perf_counter() + with 
jax.named_scope("video_vae_decode"): + if getattr(self.vae.config, "timestep_conditioning", False): + noise = jax.random.normal(generator, latents.shape, dtype=latents.dtype) - timestep = jnp.array(decode_timestep, dtype=latents.dtype) - decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None] + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * batch_size + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * batch_size - latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + timestep = jnp.array(decode_timestep, dtype=latents.dtype) + decode_noise_scale = jnp.array(decode_noise_scale, dtype=latents.dtype)[:, None, None, None, None] + + latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise + + latents = latents.astype(self.vae.dtype) + video = self.vae.decode(latents, temb=timestep, return_dict=False)[0] + else: + latents = latents.astype(self.vae.dtype) + video = self.vae.decode(latents, return_dict=False)[0] + + video = video.block_until_ready() + video_vae_time = time.perf_counter() - t0_video_vae + max_logging.log(f"Video VAE decode time: {video_vae_time:.2f}s") - latents = latents.astype(self.vae.dtype) - video = self.vae.decode(latents, temb=timestep, return_dict=False)[0] - else: - latents = latents.astype(self.vae.dtype) - video = self.vae.decode(latents, return_dict=False)[0] # Post-process video (converts to numpy/PIL) # VAE outputs (B, T, H, W, C), but video processor expects (B, C, T, H, W) + t0_video_post = time.perf_counter() video_np = np.array(video).transpose(0, 4, 1, 2, 3) video = self.video_processor.postprocess_video(torch.from_numpy(video_np), output_type=output_type) + video_post_time = time.perf_counter() - t0_video_post + max_logging.log(f"Video Post-processing time (numpy+PIL): {video_post_time:.2f}s") # Decode Audio 
- audio_latents = audio_latents.astype(self.audio_vae.dtype) - generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + t0_audio_vae = time.perf_counter() + with jax.named_scope("audio_vae_decode"): + audio_latents = audio_latents.astype(self.audio_vae.dtype) + generated_mel_spectrograms = self.audio_vae.decode(audio_latents, return_dict=False)[0] + generated_mel_spectrograms = generated_mel_spectrograms.block_until_ready() + audio_vae_time = time.perf_counter() - t0_audio_vae + max_logging.log(f"Audio VAE decode time: {audio_vae_time:.2f}s") # Audio VAE outputs (B, T, F, C), Vocoder expects (B, Channels, Time, MelBins) - generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2) - audio = self.vocoder(generated_mel_spectrograms) + t0_vocoder = time.perf_counter() + with jax.named_scope("vocoder_pass"): + generated_mel_spectrograms = generated_mel_spectrograms.transpose(0, 3, 1, 2) + audio = self.vocoder(generated_mel_spectrograms) # Convert audio to numpy audio = np.array(audio) + vocoder_time = time.perf_counter() - t0_vocoder + max_logging.log(f"Vocoder & Audio numpy time: {vocoder_time:.2f}s") + + timings["Video VAE"] = video_vae_time + timings["Video Post"] = video_post_time + timings["Audio VAE"] = audio_vae_time + timings["Vocoder"] = vocoder_time - return LTX2PipelineOutput(frames=video, audio=audio) + return LTX2PipelineOutput(frames=video, audio=audio, timings=timings) @partial( @@ -1666,51 +1736,55 @@ def scan_body(carry, t, model): # Expand timestep to batch size t_expanded = jnp.expand_dims(t, 0).repeat(latents.shape[0]) - noise_pred, noise_pred_audio = model( - hidden_states=latents_sharded, - encoder_hidden_states=video_embeds_sharded, - timestep=t_expanded, - encoder_attention_mask=new_attention_mask, - num_frames=latent_num_frames, - height=latent_height, - width=latent_width, - audio_hidden_states=audio_latents_sharded, - audio_encoder_hidden_states=audio_embeds_sharded, - 
audio_encoder_attention_mask=new_attention_mask, - fps=fps, - audio_num_frames=audio_num_frames, - return_dict=False, - ) - - if guidance_scale > 1.0: - noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # Audio guidance - ( - noise_pred_audio_uncond, - noise_pred_audio_text, - ) = jnp.split(noise_pred_audio, 2, axis=0) - noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) - - latents_step = latents[batch_size:] - audio_latents_step = audio_latents[batch_size:] - else: - latents_step = latents - audio_latents_step = audio_latents - - # Step scheduler - latents_step, _ = scheduler_step(s_state, noise_pred, t, latents_step, return_dict=False) - latents_step = latents_step.astype(latents.dtype) - - audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False) - audio_latents_step = audio_latents_step.astype(audio_latents.dtype) + with jax.named_scope("transformer_forward_pass"): + noise_pred, noise_pred_audio = model( + hidden_states=latents_sharded, + encoder_hidden_states=video_embeds_sharded, + timestep=t_expanded, + encoder_attention_mask=new_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + audio_hidden_states=audio_latents_sharded, + audio_encoder_hidden_states=audio_embeds_sharded, + audio_encoder_attention_mask=new_attention_mask, + fps=fps, + audio_num_frames=audio_num_frames, + return_dict=False, + ) - if guidance_scale > 1.0: - latents_next = jnp.concatenate([latents_step] * 2, axis=0) - audio_latents_next = jnp.concatenate([audio_latents_step] * 2, axis=0) - else: - latents_next = latents_step - audio_latents_next = audio_latents_step + with jax.named_scope("classifier_free_guidance"): + if guidance_scale > 1.0: + noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = 
noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + # Audio guidance + ( + noise_pred_audio_uncond, + noise_pred_audio_text, + ) = jnp.split(noise_pred_audio, 2, axis=0) + noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) + + latents_step = latents[batch_size:] + audio_latents_step = audio_latents[batch_size:] + else: + latents_step = latents + audio_latents_step = audio_latents + + with jax.named_scope("scheduler_step"): + # Step scheduler + latents_step, _ = scheduler_step(s_state, noise_pred, t, latents_step, return_dict=False) + latents_step = latents_step.astype(latents.dtype) + + audio_latents_step, _ = scheduler_step(s_state, noise_pred_audio, t, audio_latents_step, return_dict=False) + audio_latents_step = audio_latents_step.astype(audio_latents.dtype) + + with jax.named_scope("latent_concatenation"): + if guidance_scale > 1.0: + latents_next = jnp.concatenate([latents_step] * 2, axis=0) + audio_latents_next = jnp.concatenate([audio_latents_step] * 2, axis=0) + else: + latents_next = latents_step + audio_latents_next = audio_latents_step new_carry = (latents_next, audio_latents_next, s_state) return new_carry, None