[Common/PyTorch/JAX] make offset of ClampedSwiGLU configurable #2938
hxbai wants to merge 2 commits into NVIDIA:main
Conversation
Greptile Summary: This PR makes the linear offset of ClampedSwiGLU configurable (`glu_linear_offset`, default 1.0) across the common C/C++ core and the PyTorch and JAX bindings.
Confidence Score: 3/5. The logic is correct throughout, but the public C API signature change is undeclared and can silently corrupt stream arguments in pre-compiled consumers. A single P1 finding would normally cap confidence at 4; the severity here (an ABI break that passes junk as a CUDA stream) is significant enough to pull it to 3 until the breaking-change implications are acknowledged and handled.

Important Files Changed: `transformer_engine/common/include/transformer_engine/activation.h` — public C API signature change needs an explicit versioning / deprecation notice.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["ClampedSwiGLU / ScaledClampedQGeGLU\n(Python, glu_linear_offset=1.0 default)"]
    B["tex.clamped_swiglu / tex.clamped_dswiglu\n(pybind11, default 1.0f)"]
    C["C++ clamped_swiglu / clamped_dswiglu\n(pytorch/csrc/extensions/activation.cpp)"]
    D["nvte_clamped_swiglu / nvte_clamped_dswiglu\n(public C API — BREAKING change)"]
    E["ClampedSwiGLUParam{limit, alpha, glu_linear_offset}\n(common/util/math.h)"]
    F1["vectorized_pointwise.h kernel\n(val2 += p.glu_linear_offset)"]
    F2["gated_fp8.cuh kernel\n(gate_elt += p.glu_linear_offset)"]
    F3["gated_mxfp8.cuh kernel\n(gate_elt += p.glu_linear_offset)"]
    G["JAX ClampedSwigluParams\n(glu_linear_offset in hash + FFI dict)"]
    H["XLA FFI ClampedSwigluConfig\n(jax/csrc/extensions.h)"]
    I["ActLuFFI / DActLuDBiasQuantizeFFI\n(jax/csrc/extensions/activation.cpp)"]
    A --> B --> C --> D --> E
    E --> F1 & F2 & F3
    G --> H --> I --> D
```
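At the kernel level (F1–F3 above), the change reduces to adding a configurable offset to the clamped linear branch. Below is a minimal NumPy reference sketch, assuming a GPT-OSS-style formulation (gate clamped from above, linear component clamped symmetrically); the function name and defaults mirror the PR, but the exact tensor layout and clamp semantics are assumptions, not the PR's kernel code:

```python
import numpy as np

def clamped_swiglu(x, limit=7.0, alpha=1.702, glu_linear_offset=1.0):
    # Split the last dimension into gate and linear halves
    # (layout assumption; the TE kernels may arrange gate/linear differently).
    gate, linear = np.split(x, 2, axis=-1)
    gate = np.minimum(gate, limit)           # clamp gate from above only (assumed)
    linear = np.clip(linear, -limit, limit)  # clamp linear symmetrically (assumed)
    silu = gate * (1.0 / (1.0 + np.exp(-alpha * gate)))
    # Previously a hard-coded "+ 1.0"; now configurable:
    return silu * (linear + glu_linear_offset)
```

With `glu_linear_offset=1.0` this reproduces the old hard-coded behavior; `glu_linear_offset=0.0` gives the plain gated product.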
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
```c
 * \param[in] glu_linear_offset  Offset added to the linear component after clamping (default 1.0).
 * \param[in] stream             CUDA stream used for the operation.
 */
```
`nvte_clamped_swiglu` and `nvte_clamped_dswiglu` are public symbols declared in a versioned public header. Inserting `glu_linear_offset` before the `cudaStream_t` parameter is an ABI-breaking change: any external binary or shared library compiled against the old header will silently pass the stream pointer as the offset and a garbage value as the stream, leading to undefined behavior at runtime rather than a clean compile error when called through a pre-compiled library. This should be acknowledged as a breaking change in the PR checklist, and — if this library follows semantic versioning or a compatibility guarantee — a deprecation/transition path or version bump is needed.
Description
The previous ClampedSwiGLU follows GPT-OSS, which hard-codes the offset to 1.0.
DeepSeek-V4 uses ClampedSwiGLU without the alpha scaling or the offset.
This PR makes the offset of ClampedSwiGLU configurable to support DeepSeek-V4.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:

- Add `glu_linear_offset` (default 1.0) to `ClampedSwiGLUParam` in `common/util/math.h` and apply it in the vectorized, FP8, and MXFP8 gated-activation kernels.
- Extend the public C API signatures of `nvte_clamped_swiglu` / `nvte_clamped_dswiglu` with `glu_linear_offset`.
- Plumb the parameter through the PyTorch pybind11 bindings and the Python `ClampedSwiGLU` module.
- Add `glu_linear_offset` to the JAX `ClampedSwigluParams` (hash + FFI dict), the XLA FFI `ClampedSwigluConfig`, and the FFI handlers.
Checklist: