
Correctly pad scaling factor inverses to satisfy cuteDSL requirements #2924

Open

ksivaman wants to merge 8 commits into NVIDIA:main from ksivaman:pad_weight_scale_inv

Conversation

@ksivaman (Member)

Description

Fix the grouped MXFP8 swizzle when per-expert rows are not a multiple of 128, and pad each expert's scale inverses to multiples of (128, 4).
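For reference, a minimal sketch of the per-tensor padded shape the scale inverses are brought to (the helper name padded_scale_shape and the 32-element scaling block are assumptions for illustration, not code from this PR):

#include <cstddef>
#include <utility>

// Per-tensor padded shape of a rowwise MXFP8 scale-inverse buffer for an M x K
// expert, assuming one scale factor per 32 contiguous elements.
inline std::pair<size_t, size_t> padded_scale_shape(size_t M, size_t K) {
  constexpr size_t kScaleBlock = 32;
  const size_t scale_K  = (K + kScaleBlock - 1) / kScaleBlock;   // DIVUP(K, 32)
  const size_t padded_m = ((M + 127) / 128) * 128;               // roundup(M, 128)
  const size_t padded_k = ((scale_K + 3) / 4) * 4;               // roundup(scale_K, 4)
  return {padded_m, padded_k};
}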

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Ensure scaling factor inverses are padded to multiples of (128, 4) per tensor.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman (Member Author)

/te-ci

@greptile-apps (Contributor) commented Apr 24, 2026

Greptile Summary

This PR fixes grouped MXFP8 swizzle when per-expert rows are not a multiple of 128. The core change introduces a "compact" vs "per-tensor-padded" layout distinction: the quantize kernel writes a compact buffer (no padding between experts), while the swizzle output must be padded to (roundup(M,128), roundup(DIVUP(K,32),4)) per tensor for cuDNN consumption. The fix detects the input layout by comparing buffer sizes, sets separate input_stride_bytes/output_stride_bytes for the grouped kernels, and adds IS_PADDED_K/IS_PADDED_M compile-time template flags to prevent out-of-bounds loads past the compact per-tensor extent.
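A rough sketch of that size-based layout detection, written as standalone code with assumed names (not the PR's actual implementation):

#include <cstddef>
#include <stdexcept>

enum class ScaleLayout { PerTensorPadded, Compact };

// Classify the incoming grouped scale_inv buffer purely by its element count:
// per-tensor padded (num_tensors copies of the padded extent) or compact
// (experts packed back to back with only trailing group-level alignment).
inline ScaleLayout detect_scale_layout(size_t input_numel, size_t num_tensors,
                                       size_t padded_scale_elems,
                                       size_t compact_total_scale_elems) {
  if (input_numel == num_tensors * padded_scale_elems) return ScaleLayout::PerTensorPadded;
  if (input_numel == compact_total_scale_elems) return ScaleLayout::Compact;
  throw std::runtime_error("Grouped scale_inv size matches neither layout");
}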

Confidence Score: 5/5

Safe to merge; no P0/P1 issues found; logic is correct and well-tested across edge cases.

All findings are P2 or below. The compact-layout detection, OOB-load prevention, and output buffer allocation are logically correct and consistent between swizzle.cu and swizzle.cpp. The test suite covers aligned, unaligned, and mixed shapes including the originally-failing workload shape.

No files require special attention.

Important Files Changed

  • transformer_engine/common/swizzle/swizzle.cu - Core fix: adds compact-layout detection, separate input/output strides, and IS_PADDED_K/IS_PADDED_M template specializations to avoid OOB loads in the grouped uniform-shape kernels. Logic is correct; earlier padding_m/padding_k compound checks are cleanly separated into orthogonal guards.
  • transformer_engine/pytorch/csrc/extensions/swizzle.cpp - Python-facing layer: allocates the output buffer in the correct per-tensor padded shape (roundup(M,128), roundup(DIVUP(K,32),4)) instead of using the compact input shape, so cuDNN sees the right strides between experts. The compute_padded_grouped_scale_shape lambda correctly mirrors the swizzle.cu padding formulas for both rowwise and colwise directions.
  • tests/cpp/operator/test_swizzle.cu - Adds SwizzleGroupedCompactInputTestSuite covering aligned/unaligned M, unaligned K, and combinations; includes the originally failing shape (2, 2880, 2880). The gather_compact_grouped_scale helper faithfully replicates the quantize kernel's layout, including the trailing group-level alignment.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["maybe_swizzle_grouped_tensor (swizzle.cpp)"]
    A -->|allocate output| B["compute_padded_grouped_scale_shape\nnum_tensors × roundup(M,128) × roundup(⌈K/32⌉,4)"]
    A --> C["nvte_swizzle_grouped_scaling_factors (swizzle.cu)"]

    C --> D{Detect input layout}
    D -->|numel == num_tensors × padded_scale_elems| E["input_is_compact = false\ninput_stride = padded_m × padded_k"]
    D -->|numel == compact_total_scale_elems| F["input_is_compact = true\ninput_stride = m × padded_k (rowwise)\nor ⌈M/32⌉ × padded_m (colwise)"]
    D -->|mismatch| G[NVTE_ERROR]

    E --> H[dispatch_swizzle_*_kernel_impl]
    F --> H

    H -->|IS_PADDED_M=true, row ≥ original_M| I[Zero register, skip __ldg]
    H -->|IS_PADDED_K=true, k_coord ≥ original_K| J[Zero register, skip __ldg]
    H -->|in-bounds| K["__ldg + per-byte boundary zeroing"]

    I --> L[Output: per-tensor padded layout\noutput_stride = padded_m × padded_k]
    J --> L
    K --> L

Reviews (3): Last reviewed commit: "Add test for swizzle + padding fusion"

const auto logical_shape_nvte = input.logical_shape();
NVTE_CHECK(logical_shape_nvte.ndim >= 2,
"Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;

P2 Silent truncation when logical_shape_nvte.data[0] is not divisible by num_tensors

per_tensor_first_dim is computed with plain integer division. If logical_shape_nvte.data[0] is not an exact multiple of num_tensors (e.g. due to a caller bug or unexpected grouped layout), the result is silently truncated, causing padded_m to be underestimated and the output buffer to be too small. A divisibility assertion would catch this much earlier with a clear error message.

Suggested change (adds a divisibility check after the existing line):

const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
NVTE_CHECK(logical_shape_nvte.data[0] % num_tensors == 0,
           "Grouped GEMM swizzle expects logical_shape first dim to be divisible by num_tensors.");

Comment on lines +2077 to 2087
bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems) {
input_is_compact = true;
} else {
NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
"Grouped input columnwise_scale_inv size does not match expected packed size.");
NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
"Grouped output columnwise_scale_inv size does not match expected packed size.");
NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
" size does not match expected packed size (got ", input_scale_numel,
", expected either ", input->num_tensors * padded_scale_elems,
" (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
}

P2 Implicit contract on compact-buffer alignment is not validated

The compact_total_scale_elems formula assumes the upstream quantize kernel allocates the compact scale buffer with its total first dim rounded up to 128 (rowwise) or 4 (colwise). If a caller passes a "plain compact" buffer of size exactly num_tensors * m * padded_k (without trailing alignment slack), neither branch matches and NVTE_ERROR fires with a size-mismatch message that may be hard to diagnose.

Consider also accepting num_tensors * compact_scale_elems as a valid compact size, or documenting this alignment requirement in the error message.
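A rough sketch of that relaxation, with identifiers borrowed from the snippet above (compact_scale_elems is assumed to be the unaligned per-tensor compact size; this is not the PR's code):

bool input_is_compact;
if (input_scale_numel == input->num_tensors * padded_scale_elems) {
  input_is_compact = false;
} else if (input_scale_numel == compact_total_scale_elems ||
           input_scale_numel == input->num_tensors * compact_scale_elems) {
  // Also accept a plain compact buffer without trailing group-level alignment slack.
  input_is_compact = true;
} else {
  NVTE_ERROR("Grouped scale_inv size matches neither the per-tensor padded, the aligned compact, "
             "nor the plain compact layout (got ", input_scale_numel, ").");
}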

@ptrendx (Member) commented Apr 24, 2026

@ksivaman Could you add a test exercising the change?

@ksivaman (Member Author)

/te-ci

@Oleg-Goncharov (Collaborator) left a comment

LGTM overall

size_t group_first_align;
if (rowwise) {
per_tensor_first_unpadded = M_per_tensor;
const size_t scale_K = (K_per_tensor + BLOCK - 1) / BLOCK;

We already have divide_round_up and round_up_to_nearest_multiple helpers for this. Could you please use them instead?
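For illustration, the last line of the snippet above rewritten with those helpers (exact helper signatures assumed):

// Assuming size_t divide_round_up(size_t x, size_t y) from the common utilities.
const size_t scale_K = divide_round_up(K_per_tensor, BLOCK);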

const NVTEShape rs = input->rowwise_scale_inv_shape();
zero_scale_inv_padding(input->rowwise_cpu_scale_inv_ptr<uint8_t>(),
rs.data[0], rs.data[1],
M, (K + BLOCK_SIZE - 1) / BLOCK_SIZE);

Here too

const NVTEShape cs = input->columnwise_scale_inv_shape();
zero_scale_inv_padding(input->columnwise_cpu_scale_inv_ptr<uint8_t>(),
cs.data[0], cs.data[1],
(M + BLOCK_SIZE - 1) / BLOCK_SIZE, K);

And here

void* output_ptr = rowwise ? output->scale_inv.dptr : output->columnwise_scale_inv.dptr;

if (rowwise) {
switch (vec_load_size) {

To avoid code duplication, I'd suggest replacing this switch with a macro similar to the one below:

#define TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(INTEGER_ELTS_NUM, type, ...) \
  switch (INTEGER_ELTS_NUM) {                                                               \
    case 1: {                                                                               \
      using type = int;                                                                     \
      { __VA_ARGS__ }                                                                       \
    } break;                                                                                \
    case 2: {                                                                               \
      using type = int2;                                                                    \
      { __VA_ARGS__ }                                                                       \
    } break;                                                                                \
    case 4: {                                                                               \
      using type = int4;                                                                    \
      { __VA_ARGS__ }                                                                       \
    } break;                                                                                \
    default: {                                                                              \
      NVTE_ERROR("Unsupported number of integer elements ", INTEGER_ELTS_NUM,               \
                 ". Expected one of: 1, 2, or 4.");                                         \
    }                                                                                       \
  }
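A hypothetical use of this macro in place of the rowwise switch; the launch function and its arguments are placeholders, not the PR's actual kernel call:

// The macro binds the selected vector type (int/int2/int4) to LType and runs the body once.
TRANSFORMER_ENGINE_VECTORIZED_LOAD_INTEGER_TYPE_SWITCH(vec_load_size, LType,
  launch_rowwise_swizzle<LType>(input_ptr, output_ptr, stream);  // placeholder call
);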

NVTE_ERROR("Not valid vec_load_size.");
}
} else {
switch (vec_load_size) {

And this switch too

// Per-byte K masking is still needed when only part of the register is past
// original_K (i.e. row is in range but the K position spans the boundary).
if constexpr (IS_PADDED_K) {
for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {

Adding #pragma unroll here would help performance.
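That is, something along these lines (loop body elided; identifiers taken from the snippet above):

if constexpr (IS_PADDED_K) {
  #pragma unroll
  for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) {
    // per-byte K masking, unchanged
  }
}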

