Skip to content

[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable#2938

Open
hxbai wants to merge 2 commits intoNVIDIA:mainfrom
hxbai:swiglu_offset
Open

[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable#2938
hxbai wants to merge 2 commits intoNVIDIA:mainfrom
hxbai:swiglu_offset

Conversation

@hxbai
Copy link
Copy Markdown
Contributor

@hxbai hxbai commented Apr 28, 2026

Description

The previous ClampedSwiGLU follows GPT-OSS, which hard-coded the offset 1.0.
DeepSeek-V4 uses ClampedSwiGLU without alpha and offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR makes the glu_linear_offset of ClampedSwiGLU configurable (default 1.0, matching GPT-OSS behavior) so that DeepSeek-V4's variant — which uses no offset — can be expressed without forking the kernel. The change propagates consistently through the CUDA kernels, FP8/MXFP8 cast kernels, PyTorch C++ extensions, JAX FFI bindings, and the pure-Python fallback in LayerNormMLP, with matching test coverage for both 1.0 and 0.0.

  • The signatures of nvte_clamped_swiglu and nvte_clamped_dswiglu in the public C header activation.h are changed by inserting glu_linear_offset before cudaStream_t. External callers compiled against the old header will silently pass the stream pointer as the float offset and an uninitialized value as the stream, causing undefined behavior. The PR checklist does not mark "Breaking change."

Confidence Score: 3/5

Correct logic throughout, but the public C API signature change is undeclared and can silently corrupt stream arguments in pre-compiled consumers.

A single P1 finding caps confidence at 4; the severity here (ABI break that passes junk as a CUDA stream) is significant enough to pull to 3 until the breaking-change implications are acknowledged and handled.

transformer_engine/common/include/transformer_engine/activation.h — public C API signature change needs explicit versioning / deprecation notice

Important Files Changed

Filename Overview
transformer_engine/common/include/transformer_engine/activation.h Public C API signatures for nvte_clamped_swiglu / nvte_clamped_dswiglu changed by inserting glu_linear_offset before stream — breaking change for external callers
transformer_engine/common/util/math.h ClampedSwiGLUParam struct gains glu_linear_offset field with default 1.0f, preserving backward-compatible defaults
transformer_engine/common/util/vectorized_pointwise.h Forward and backward kernels correctly replace hardcoded +1 / +1.0f with +p.glu_linear_offset; gradient math is unchanged and correct
transformer_engine/common/cast/fp8/gated_fp8.cuh Uses p.glu_linear_offset instead of hardcoded 1 in both the FP8 fused kernel forward/backward paths
transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh Both MX-FP8 kernel variants updated consistently to use p.glu_linear_offset
transformer_engine/pytorch/ops/basic/swiglu.py ClampedSwiGLU and ScaledClampedQGeGLU both gain glu_linear_offset kwarg with default 1.0, properly threaded to tex calls
transformer_engine/jax/cpp_extensions/activation.py ClampedSwigluParams dataclass and clamped_linear lambda updated; hash now includes glu_linear_offset for correct JAX JIT cache keying
tests/pytorch/test_fusible_ops.py test_clamped_swiglu and test_scaled_clamped_qgeglu parametrized over glu_linear_offset in [1.0, 0.0]; reference formula updated correctly

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["ClampedSwiGLU / ScaledClampedQGeGLU\n(Python, glu_linear_offset=1.0 default)"]
    B["tex.clamped_swiglu / tex.clamped_dswiglu\n(pybind11, default 1.0f)"]
    C["C++ clamped_swiglu / clamped_dswiglu\n(pytorch/csrc/extensions/activation.cpp)"]
    D["nvte_clamped_swiglu / nvte_clamped_dswiglu\n(public C API — BREAKING change)"]
    E["ClampedSwiGLUParam{limit, alpha, glu_linear_offset}\n(common/util/math.h)"]
    F1["vectorized_pointwise.h kernel\n(val2 += p.glu_linear_offset)"]
    F2["gated_fp8.cuh kernel\n(gate_elt += p.glu_linear_offset)"]
    F3["gated_mxfp8.cuh kernel\n(gate_elt += p.glu_linear_offset)"]
    G["JAX ClampedSwigluParams\n(glu_linear_offset in hash + FFI dict)"]
    H["XLA FFI ClampedSwigluConfig\n(jax/csrc/extensions.h)"]
    I["ActLuFFI / DActLuDBiasQuantizeFFI\n(jax/csrc/extensions/activation.cpp)"]
    A --> B --> C --> D --> E
    E --> F1 & F2 & F3
    G --> H --> I --> D
Loading

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +339 to 341
* \param[in] glu_linear_offset Offset added to the linear component after clamping (default 1.0).
* \param[in] stream CUDA stream used for the operation.
*/
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Breaking public C API change

nvte_clamped_swiglu and nvte_clamped_dswiglu are public symbols declared in a versioned public header. Inserting glu_linear_offset before cudaStream_t is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error if called via a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant