Correctly pad scaling factor inverses to satisfy cuteDSL requirements #2924

ksivaman wants to merge 8 commits into NVIDIA:main from
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
Greptile Summary

This PR fixes grouped MXFP8 swizzle when per-expert rows are not a multiple of 128. The core change introduces a "compact" vs "per-tensor-padded" layout distinction: the quantize kernel writes a compact buffer (no padding between experts), while the swizzle output must be padded per expert to roundup(M, 128) × roundup(⌈K/32⌉, 4).

Confidence Score: 5/5

Safe to merge; no P0/P1 issues found; the logic is correct and well tested across edge cases. All findings are P2 or below. The compact-layout detection, OOB-load prevention, and output buffer allocation are logically correct and consistent between swizzle.cu and swizzle.cpp. The test suite covers aligned, unaligned, and mixed shapes, including the originally failing workload shape. No files require special attention.

Important Files Changed
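For concreteness, a minimal sketch of the per-tensor padded size that the layout detection compares against (round_up and padded_scale_elems are illustrative names; the PR's actual helpers may differ):

#include <cstddef>

// Round x up to the nearest multiple of m (assumes m > 0).
constexpr std::size_t round_up(std::size_t x, std::size_t m) {
  return ((x + m - 1) / m) * m;
}

// Per-tensor padded rowwise scale size: each expert occupies a full
// roundup(M, 128) x roundup(ceil(K / 32), 4) tile, per the flowchart below.
constexpr std::size_t padded_scale_elems(std::size_t m, std::size_t k) {
  return round_up(m, 128) * round_up((k + 31) / 32, 4);
}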
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["maybe_swizzle_grouped_tensor (swizzle.cpp)"]
A -->|allocate output| B["compute_padded_grouped_scale_shape\nnum_tensors × roundup(M,128) × roundup(⌈K/32⌉,4)"]
A --> C["nvte_swizzle_grouped_scaling_factors (swizzle.cu)"]
C --> D{Detect input layout}
D -->|numel == num_tensors × padded_scale_elems| E["input_is_compact = false\ninput_stride = padded_m × padded_k"]
D -->|numel == compact_total_scale_elems| F["input_is_compact = true\ninput_stride = m × padded_k (rowwise)\nor ⌈M/32⌉ × padded_m (colwise)"]
D -->|mismatch| G[NVTE_ERROR]
E --> H[dispatch_swizzle_*_kernel_impl]
F --> H
H -->|IS_PADDED_M=true, row ≥ original_M| I[Zero register, skip __ldg]
H -->|IS_PADDED_K=true, k_coord ≥ original_K| J[Zero register, skip __ldg]
H -->|in-bounds| K["__ldg + per-byte boundary zeroing"]
I --> L[Output: per-tensor padded layout\noutput_stride = padded_m × padded_k]
J --> L
K --> L
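A simplified device-side sketch of the "__ldg + per-byte boundary zeroing" step (the function and parameter names here are illustrative, not the kernel's actual identifiers):

// Illustrative only: load one packed 4-byte word of scale factors, then
// zero any bytes whose K coordinate lies at or past the unpadded K extent.
__device__ int load_masked_word(const int *src, int k_coord, int original_K_scales) {
  int word = __ldg(src);  // vectorized read of 4 packed scale bytes
  unsigned char *bytes = reinterpret_cast<unsigned char *>(&word);
#pragma unroll
  for (int j = 0; j < 4; ++j) {
    if (k_coord + j >= original_K_scales) bytes[j] = 0;
  }
  return word;
}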
Reviews (3): Last reviewed commit: "Add test for swizzle + padding fusion"
const auto logical_shape_nvte = input.logical_shape();
NVTE_CHECK(logical_shape_nvte.ndim >= 2,
           "Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
Silent truncation when logical_shape_nvte.data[0] is not divisible by num_tensors

per_tensor_first_dim is computed with plain integer division. If logical_shape_nvte.data[0] is not an exact multiple of num_tensors (e.g. due to a caller bug or an unexpected grouped layout), the result is silently truncated, causing padded_m to be underestimated and the output buffer to be too small. A divisibility assertion would catch this much earlier with a clear error message.
Suggested change:

const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
NVTE_CHECK(logical_shape_nvte.data[0] % num_tensors == 0,
           "Grouped GEMM swizzle expects logical_shape first dim to be divisible by num_tensors.");
bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems) {
  input_is_compact = true;
} else {
  NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
             "Grouped input columnwise_scale_inv size does not match expected packed size.");
  NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
             "Grouped output columnwise_scale_inv size does not match expected packed size.");
  NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
             " size does not match expected packed size (got ", input_scale_numel,
             ", expected either ", input->num_tensors * padded_scale_elems,
             " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
}
Implicit contract on compact-buffer alignment is not validated
The compact_total_scale_elems formula assumes the upstream quantize kernel allocates the compact scale buffer with its total first dim rounded up to 128 (rowwise) or 4 (colwise). If a caller passes a "plain compact" buffer of size exactly num_tensors * m * padded_k (without trailing alignment slack), neither branch matches and NVTE_ERROR fires with a size-mismatch message that may be hard to diagnose.
Consider also accepting num_tensors * compact_scale_elems as a valid compact size, or documenting this alignment requirement in the error message.
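A sketch of the first option, assuming compact_scale_elems denotes the per-tensor unaligned size (this extra branch and its placement are illustrative, not the PR's code):

} else if (input_scale_numel == input->num_tensors * compact_scale_elems) {
  // Plain compact buffer without trailing alignment slack.
  input_is_compact = true;
}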
@ksivaman Could you add a test exercising the change?
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
size_t group_first_align;
if (rowwise) {
  per_tensor_first_unpadded = M_per_tensor;
  const size_t scale_K = (K_per_tensor + BLOCK - 1) / BLOCK;
We already have divide_round_up and round_up_to_nearest_multiple helpers for this. Could you please use them instead?
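For example, a one-line sketch assuming divide_round_up takes the value first and the divisor second:

const size_t scale_K = divide_round_up(K_per_tensor, BLOCK);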
const NVTEShape rs = input->rowwise_scale_inv_shape();
zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr<uint8_t>(),
                       rs.data[0], rs.data[1],
                       M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);

const NVTEShape cs = input->columnwise_scale_inv_shape();
zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr<uint8_t>(),
                       cs.data[0], cs.data[1],
                       (M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);
void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;

if (rowwise) {
  switch (vec_load_size) {
To avoid code duplication, I'd suggest replacing this switch with a macro similar to the one below:
#define TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(INTEGER_ELTS_NUM, type, ...) \
switch (INTEGER_ELTS_NUM) { \
case 1: { \
using type = int; \
{ __VA_ARGS__ } \
} break; \
case 2: { \
using type = int2; \
{ __VA_ARGS__ } \
} break; \
case 4: { \
using type = int4; \
{ __VA_ARGS__ } \
} break; \
default: { \
NVTE_ERROR("Unsupported number of integer elements ", INTEGER_ELTS_NUM, \
". Expected one of: 1, 2, or 4."); \
} \
}
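The duplicated rowwise/columnwise switches could then collapse to a single dispatch along these lines (launch_swizzle_kernel is a placeholder name, not the PR's actual dispatch function):

TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(
    vec_load_size, LoadType,
    launch_swizzle_kernel<LoadType>(input_ptr, output_ptr, stream););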
    NVTE_ERROR("Not valid vec_load_size.");
  }
} else {
  switch (vec_load_size) {
And this switch too.
// Per-byte K masking is still needed when only part of the register is past
// original_K (i.e. row is in range but the K position spans the boundary).
if constexpr (IS_PADDED_K) {
  for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
Adding #pragma unroll here would help performance.
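A sketch of the suggestion; the trip count N_TILE_PER_TD * sizeof(int) is a compile-time constant, so the compiler can fully unroll the per-byte masking loop:

if constexpr (IS_PADDED_K) {
#pragma unroll
  for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
    // ... per-byte boundary zeroing as in the hunk above ...
  }
}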
Description
Fix grouped MXFP8 swizzle when per-expert rows aren't a multiple of 128 and pad each expert's scales to (128, 4).
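As a worked example (values chosen for illustration): an expert with M = 100 rows and K = 512 columns has a rowwise scale-inverse shape of (100, ⌈512/32⌉) = (100, 16); padding to the required multiples of (128, 4) yields (128, 16), with the 28 padding rows zero-filled.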
Type of change
Changes
Checklist: