Fix CUDA graph parameter grad lifetime #2937

Open

buptzyb wants to merge 2 commits into NVIDIA:main from buptzyb:fix/cudagraph-wgrad-lifetime

Conversation

buptzyb (Contributor) commented Apr 28, 2026

Summary

Fix CUDA graph replay so parameter gradients returned from Graphed.backward do not expose CUDA graph static buffers to downstream autograd users.

The fix clones returned parameter gradients before handing them back to autograd, while preserving the existing aliasing behavior for delayed-wgrad parameters marked with skip_backward_post_hook.

Root Cause

When CUDA graph replay returns parameter grad slots directly from static graph buffers, downstream autograd users can retain references to buffers that are overwritten by later graph replays. This can corrupt retained grads or break gradient accumulation semantics.

This is related to PyTorch issue pytorch/pytorch#181723.
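
To make the hazard concrete, here is a minimal standalone sketch (stock PyTorch only, not TransformerEngine code; it assumes a CUDA device is available) showing how a tensor that aliases a graph's static buffer is silently rewritten by the next replay, while a clone is not:

import torch

# Static buffers whose storage is owned by the CUDA graph's memory pool.
static_in = torch.zeros(4, device="cuda")
static_out = torch.zeros(4, device="cuda")

# Warm up on a side stream before capture, as CUDA graph capture requires.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(side_stream)

cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(cuda_graph):
    static_out.copy_(static_in * 2)

static_in.fill_(1.0)
cuda_graph.replay()
aliased = static_out          # still points at the static buffer
owned = static_out.clone()    # what this PR does for returned parameter grads

static_in.fill_(3.0)
cuda_graph.replay()           # overwrites the static buffer in place
print(aliased)                # now all 6.0, not the 2.0 the caller retained
print(owned)                  # still 2.0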

Changes

  • Detect parameter grad slots in the graphed autograd input surface.
  • Clone returned non-delayed-wgrad parameter grads before returning from Graphed.backward (see the sketch after this list).
  • Allow the reused graph input/output buffer mode to weak-ref the parameter grad static buffers immediately after capture, since returned grads are now cloned.
  • Add CUDA graph tests for owned returned parameter grads, accumulation, delayed-wgrad alias preservation, and reused buffer interleaved pipeline replay.

Signed-off-by: Robin Zhang <robinz@nvidia.com>
buptzyb force-pushed the fix/cudagraph-wgrad-lifetime branch from 4cc8b89 to 6e16c63 on April 28, 2026 08:18
buptzyb marked this pull request as ready for review on April 28, 2026 08:20
greptile-apps bot (Contributor) commented Apr 28, 2026

Greptile Summary

This PR fixes a CUDA graph replay bug where returned parameter gradients aliased the graph's static input buffers, allowing later replays to silently overwrite tensors that downstream autograd consumers (e.g. gradient hooks, param.grad accumulation) still hold references to. The fix snapshots which gradient slots correspond to non-delayed-wgrad parameters at capture time (_returned_param_grad_slots), then clones those slots before returning from Graphed.backward; the _reuse_graph_input_output_buffers path is updated to weak-ref param grad static buffers immediately after capture because the clone-on-return guarantee makes it safe to do so. Four new unit tests cover ownership, accumulation, delayed-wgrad alias preservation, and the capture-time policy snapshot; the interleaved pipeline test is also extended to compare final model weights, not just forward outputs.

Confidence Score: 5/5

Safe to merge; the clone-on-return logic is correct, the capture-time snapshotting is sound, and all changed code paths are covered by new tests.

No P0 or P1 findings. The only comment is a P2 style concern about using rtol=0/atol=0 in a float32 accumulation test that compares eager einsum against cuBLAS-backed CUDA graph ops, which could be fragile on some hardware. Core logic in graph.py is well-reasoned and consistent with the existing weak-ref memory management pattern.

No files require special attention; test_cuda_graphs.py has the minor float-tolerance concern noted above.

Important Files Changed

transformer_engine/pytorch/graph.py: Adds the _returned_param_grad_slots helper, snapshotted at capture time; clones returned param grads in Graphed.backward to prevent static buffer aliasing; extends the _reuse_graph_input_output_buffers path to weak-ref param grad slots once the clone-on-return guarantee is established.

tests/pytorch/test_cuda_graphs.py: Adds four new unit tests for param grad ownership, accumulation, delayed-wgrad alias preservation, and capture-time policy snapshotting; extends the interleaved pipeline test to compare final weights in addition to forward outputs; adds a reused-buffer pipeline variant.

Sequence Diagram

sequenceDiagram
    participant Capture as Capture Time
    participant Replay as Graphed.backward
    participant Autograd as PyTorch Autograd

    Capture->>Capture: _returned_param_grad_slots(static_grad_inputs, module_params)
    Note over Capture: Snapshot skip_backward_post_hook per param into bool tuple stored in closure

    alt _reuse_graph_input_output_buffers=True
        Capture->>Capture: make_weak_ref(param_grad_slot) for each True slot
        Note over Capture: Python tensor dropped but CUDA graph memory pool still holds buffer
    end

    Replay->>Replay: bwd_graph.replay()
    Note over Replay: CUDA graph writes gradient into static param grad buffer

    loop each static_grad_input at idx
        alt grad_input is None
            Replay->>Autograd: None
        else returned_param_grad_slots idx is True
            Replay->>Replay: grad_input.detach().clone()
            Replay->>Autograd: fresh clone
            Note over Replay: Subsequent replays overwrite static buffer but param.grad clone is unaffected
        else activation or non-param grad
            Replay->>Autograd: grad_input.detach()
        end
    end

    Autograd->>Autograd: accumulate_grad accumulates clone into param.grad

Reviews (4). Last reviewed commit: "Address CUDA graph grad lifetime review ..."

Comment thread: transformer_engine/pytorch/graph.py (Outdated)
Comment on lines +410 to +417
def _is_returned_param_grad_slot(idx, static_grad_inputs, module_params):
    """Return whether a static grad slot is consumed through Graphed.backward."""
    module_param_start = len(static_grad_inputs) - len(module_params)
    if idx < module_param_start:
        return False
    return not getattr(
        module_params[idx - module_param_start], "skip_backward_post_hook", False
    )
Contributor

P2 Timing inconsistency between capture and replay attribute reads

_is_returned_param_grad_slot reads skip_backward_post_hook live at both capture time (line 748) and replay time (line 945). If a caller flips the attribute between those two points, the weak-ref decision at capture and the clone decision at replay get out of sync.

Specifically, if the attribute was False at capture (→ static buffer was weak-refed in the _reuse_graph_input_output_buffers path) but True at replay (→ code calls .detach() instead of .detach().clone()), the returned tensor is a detached view of an already-released weak-ref buffer whose memory may have been reused. Snapshotting the skip_backward_post_hook state once at capture time and storing it alongside the static grad slot (or asserting it is unchanged at replay) would make the contract explicit.

Contributor Author

Addressed in 4077b85. The parameter-grad clone policy is now snapshotted at capture time and passed into Graphed.backward, so replay no longer re-reads skip_backward_post_hook. Added test_make_graphed_callables_snapshots_parameter_grad_clone_policy to cover changing the attribute after capture.
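
A minimal sketch of that capture-time snapshot (illustrative only; the real helper in graph.py may differ in signature and placement). The idea is to compute a per-slot bool tuple once during capture and close over it, so replay never re-reads skip_backward_post_hook:

def _returned_param_grad_slots(static_grad_inputs, module_params):
    """Capture-time snapshot: True for grad slots that must be cloned on return."""
    module_param_start = len(static_grad_inputs) - len(module_params)
    return tuple(
        idx >= module_param_start
        and not getattr(
            module_params[idx - module_param_start], "skip_backward_post_hook", False
        )
        for idx in range(len(static_grad_inputs))
    )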

Comment on lines +889 to +906
def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with reused input/output buffers."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        reuse_graph_input_output_buffers=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)
Contributor

P2 Reused-buffer test only validates forward outputs, not gradient correctness

test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers compares output_snapshots (forward tensors cloned before the corresponding backward) against the eager baseline. If the clone-on-return logic in Graphed.backward had a bug specifically in the _reuse_graph_input_output_buffers + pipeline path (e.g., gradient accumulation or an incorrect static buffer being read), weights would diverge but the test would still pass. A weight-equality check after one full schedule would strengthen confidence in the gradient path for this mode.
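
Something along these lines would do it (hypothetical helper; it assumes the eager and graphed runs start from identically initialized models):

import torch

def assert_final_weights_match(eager_model: torch.nn.Module,
                               graphed_model: torch.nn.Module) -> None:
    """After one full schedule, both runs should end with identical weights."""
    for eager_param, graph_param in zip(eager_model.parameters(),
                                        graphed_model.parameters()):
        torch.testing.assert_close(eager_param.detach(), graph_param.detach())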

Contributor Author

Addressed in 4077b85. The interleaved pipeline helper now returns final weights in addition to outputs, and the reused-buffer test compares graph/eager final weights to cover gradient correctness. Full tests/pytorch/test_cuda_graphs.py passed on H100: 415 passed, 423 skipped.

buptzyb force-pushed the fix/cudagraph-wgrad-lifetime branch 2 times, most recently from 441e419 to beff9c1 on April 28, 2026 09:40
Signed-off-by: Robin Zhang <robinz@nvidia.com>
buptzyb force-pushed the fix/cudagraph-wgrad-lifetime branch from beff9c1 to 4077b85 on April 28, 2026 13:24