diff --git a/alto/kernels/fp4/nvfp4/nvfp_quantization.py b/alto/kernels/fp4/nvfp4/nvfp_quantization.py index b0597b7..af999c8 100644 --- a/alto/kernels/fp4/nvfp4/nvfp_quantization.py +++ b/alto/kernels/fp4/nvfp4/nvfp_quantization.py @@ -17,10 +17,19 @@ BLOCK_SIZE_DEFAULT = 16 +# Host-facing numeric constants -- kept as plain Python floats so eager/host +# code can use them directly (e.g. ``amax / (F8E4M3_MAX * F4_E2M1_MAX)``). F4_E2M1_MAX = 6.0 F8E4M3_MAX = 448.0 E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny +# ``@triton.jit`` kernels cannot read plain (non-constexpr) module globals on +# triton >=3.6 (raises NameError), so expose ``tl.constexpr`` mirrors for +# kernel-side use only. Derived from the values above, so they stay in sync. +_F4_E2M1_MAX = tl.constexpr(F4_E2M1_MAX) +_F8E4M3_MAX = tl.constexpr(F8E4M3_MAX) +_E4M3_EPS = tl.constexpr(E4M3_EPS) + # Naming convention for the NVFP4 scale hierarchy: # inner_scale -- per-block scale stored alongside the packed FP4 data # (NVFP4 spec ``s_block``). Value lives on the E4M3 grid @@ -232,13 +241,13 @@ def _calculate_nvfp4_scales( if USE_OUTER_SCALE: outer_scale = tl.load(outer_scale_ptr) - inner_scale_raw = max_abs / outer_scale / F4_E2M1_MAX - inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, E4M3_EPS), F8E4M3_MAX) + inner_scale_raw = max_abs / outer_scale / _F4_E2M1_MAX + inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, _E4M3_EPS), _F8E4M3_MAX) inner_scale = inner_scale_raw.to(tl.float8e4nv).to(tl.float32) quant_scale = inner_scale * outer_scale else: - inner_scale_raw = max_abs / F4_E2M1_MAX - inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, E4M3_EPS), F8E4M3_MAX) + inner_scale_raw = max_abs / _F4_E2M1_MAX + inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, _E4M3_EPS), _F8E4M3_MAX) inner_scale = inner_scale_raw.to(tl.float8e4nv).to(tl.float32) quant_scale = inner_scale