Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions alto/kernels/fp4/nvfp4/nvfp_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@


BLOCK_SIZE_DEFAULT = 16
# Host-facing numeric constants -- kept as plain Python floats so eager/host
# code can use them directly (e.g. ``amax / (F8E4M3_MAX * F4_E2M1_MAX)``).
F4_E2M1_MAX = 6.0
F8E4M3_MAX = 448.0
E4M3_EPS = torch.finfo(torch.float8_e4m3fn).tiny

# ``@triton.jit`` kernels cannot read plain (non-constexpr) module globals on
# triton >=3.6 (raises NameError), so expose ``tl.constexpr`` mirrors for
# kernel-side use only. Keep these in sync with the host-facing values above.
_F4_E2M1_MAX = tl.constexpr(F4_E2M1_MAX)
_F8E4M3_MAX = tl.constexpr(F8E4M3_MAX)
_E4M3_EPS = tl.constexpr(E4M3_EPS)

# Naming convention for the NVFP4 scale hierarchy:
# inner_scale -- per-block scale stored alongside the packed FP4 data
# (NVFP4 spec ``s_block``). Value lives on the E4M3 grid
Expand Down Expand Up @@ -232,13 +241,13 @@ def _calculate_nvfp4_scales(

if USE_OUTER_SCALE:
outer_scale = tl.load(outer_scale_ptr)
inner_scale_raw = max_abs / outer_scale / F4_E2M1_MAX
inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, E4M3_EPS), F8E4M3_MAX)
inner_scale_raw = max_abs / outer_scale / _F4_E2M1_MAX
inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, _E4M3_EPS), _F8E4M3_MAX)
inner_scale = inner_scale_raw.to(tl.float8e4nv).to(tl.float32)
quant_scale = inner_scale * outer_scale
else:
inner_scale_raw = max_abs / F4_E2M1_MAX
inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, E4M3_EPS), F8E4M3_MAX)
inner_scale_raw = max_abs / _F4_E2M1_MAX
inner_scale_raw = tl.minimum(tl.maximum(inner_scale_raw, _E4M3_EPS), _F8E4M3_MAX)
inner_scale = inner_scale_raw.to(tl.float8e4nv).to(tl.float32)
quant_scale = inner_scale

Expand Down