Fix CUDA graph parameter grad lifetime#2937
Conversation
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Force-pushed 4cc8b89 to 6e16c63
Greptile Summary

This PR fixes a CUDA graph replay bug where returned parameter gradients aliased the graph's static input buffers, allowing later replays to silently overwrite tensors that downstream autograd consumers (e.g. gradient hooks, gradient accumulation) still held.

Confidence Score: 5/5

Safe to merge; the clone-on-return logic is correct, the capture-time snapshotting is sound, and all changed code paths are covered by new tests. No P0 or P1 findings. The only comment is a P2 style concern about using rtol=0/atol=0 in a float32 accumulation test that compares eager einsum against cuBLAS-backed CUDA graph ops, which could be fragile on some hardware. Core logic in graph.py is well-reasoned and consistent with the existing weak-ref memory management pattern. No files require special attention; test_cuda_graphs.py has the minor float-tolerance concern noted above.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Capture as Capture Time
    participant Replay as Graphed.backward
    participant Autograd as PyTorch Autograd
    Capture->>Capture: _returned_param_grad_slots(static_grad_inputs, module_params)
    Note over Capture: Snapshot skip_backward_post_hook per param into bool tuple stored in closure
    alt _reuse_graph_input_output_buffers=True
        Capture->>Capture: make_weak_ref(param_grad_slot) for each True slot
        Note over Capture: Python tensor dropped but CUDA graph memory pool still holds buffer
    end
    Replay->>Replay: bwd_graph.replay()
    Note over Replay: CUDA graph writes gradient into static param grad buffer
    loop each static_grad_input at idx
        alt grad_input is None
            Replay->>Autograd: None
        else returned_param_grad_slots[idx] is True
            Replay->>Replay: grad_input.detach().clone()
            Replay->>Autograd: fresh clone
            Note over Replay: Subsequent replays overwrite static buffer but param.grad clone is unaffected
        else activation or non-param grad
            Replay->>Autograd: grad_input.detach()
        end
    end
    Autograd->>Autograd: accumulate_grad accumulates clone into param.grad
```
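The per-slot decision in the diagram's replay loop can be sketched in pure Python. This is a simplified stand-in, not the actual TransformerEngine code: `FakeTensor` emulates a tensor backed by a CUDA graph static buffer, so the aliasing difference between `.detach()` and `.detach().clone()` is observable without a GPU.

```python
class FakeTensor:
    """Minimal stand-in for a tensor backed by a CUDA graph static buffer."""

    def __init__(self, value, storage=None):
        self.storage = storage if storage is not None else [value]

    @property
    def value(self):
        return self.storage[0]

    def detach(self):
        # A detached view still aliases the same storage, like torch .detach()
        return FakeTensor(None, storage=self.storage)

    def clone(self):
        # A clone gets fresh storage, decoupled from later replays
        return FakeTensor(self.value)


def backward_returns(static_grad_inputs, returned_param_grad_slots):
    """Mirror the per-slot decision: clone returned param grads, alias the rest."""
    outs = []
    for idx, grad in enumerate(static_grad_inputs):
        if grad is None:
            outs.append(None)
        elif returned_param_grad_slots[idx]:
            outs.append(grad.detach().clone())  # safe across later replays
        else:
            outs.append(grad.detach())  # activation grad: aliasing is tolerated
    return outs


# One activation-grad slot and one parameter-grad slot
bufs = [FakeTensor(1.0), FakeTensor(2.0)]
outs = backward_returns(bufs, (False, True))
bufs[0].storage[0] = -1.0  # simulate the next replay overwriting static buffers
bufs[1].storage[0] = -2.0
print(outs[0].value, outs[1].value)  # -1.0 2.0: only the cloned param grad is stable
```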
Reviews (4): last reviewed commit "Address CUDA graph grad lifetime review ..."
```python
def _is_returned_param_grad_slot(idx, static_grad_inputs, module_params):
    """Return whether a static grad slot is consumed through Graphed.backward."""
    module_param_start = len(static_grad_inputs) - len(module_params)
    if idx < module_param_start:
        return False
    return not getattr(
        module_params[idx - module_param_start], "skip_backward_post_hook", False
    )
```
Timing inconsistency between capture and replay attribute reads
_is_returned_param_grad_slot reads skip_backward_post_hook live at both capture time (line 748) and replay time (line 945). If a caller flips the attribute between those two points, the weak-ref decision at capture and the clone decision at replay get out of sync.
Specifically, if the attribute was False at capture (→ static buffer was weak-refed in the _reuse_graph_input_output_buffers path) but True at replay (→ code calls .detach() instead of .detach().clone()), the returned tensor is a detached view of an already-released weak-ref buffer whose memory may have been reused. Snapshotting the skip_backward_post_hook state once at capture time and storing it alongside the static grad slot (or asserting it is unchanged at replay) would make the contract explicit.
Addressed in 4077b85. The parameter-grad clone policy is now snapshotted at capture time and passed into Graphed.backward, so replay no longer re-reads skip_backward_post_hook. Added test_make_graphed_callables_snapshots_parameter_grad_clone_policy to cover changing the attribute after capture.
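One way the capture-time snapshot can look, as a sketch with hypothetical names (not the actual diff): the policy tuple is computed once and is immune to later attribute flips.

```python
from types import SimpleNamespace


def snapshot_returned_param_grad_slots(static_grad_inputs, module_params):
    """Capture-time snapshot: later attribute flips cannot change replay behavior."""
    module_param_start = len(static_grad_inputs) - len(module_params)
    return tuple(
        idx >= module_param_start
        and not getattr(
            module_params[idx - module_param_start], "skip_backward_post_hook", False
        )
        for idx in range(len(static_grad_inputs))
    )


params = [SimpleNamespace(skip_backward_post_hook=False)]
slots = snapshot_returned_param_grad_slots([object(), object()], params)
params[0].skip_backward_post_hook = True  # flipped after "capture"
print(slots)  # (False, True): the snapshot is unaffected by the flip
```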
```python
def test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers(
    *,
    model_config: str = "small",
    dtype: torch.dtype = torch.float16,
) -> None:
    """Test CUDA graphs with reused input/output buffers."""
    model_config = model_configs[model_config]
    kwargs = dict(model_config=model_config, dtype=dtype)
    outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=False,
        **kwargs,
    )
    graph_outputs = _test_cuda_graphs_with_interleaved_pipeline_parallelism(
        with_graph=True,
        reuse_graph_input_output_buffers=True,
        **kwargs,
    )
    assert_all_equal(outputs, graph_outputs)
```
Reused-buffer test only validates forward outputs, not gradient correctness
test_make_graphed_callables_with_interleaved_pipeline_parallelism_reused_buffers compares output_snapshots (forward tensors cloned before the corresponding backward) against the eager baseline. If the clone-on-return logic in Graphed.backward had a bug specifically in the _reuse_graph_input_output_buffers + pipeline path (e.g., gradient accumulation or an incorrect static buffer being read), weights would diverge but the test would still pass. A weight-equality check after one full schedule would strengthen confidence in the gradient path for this mode.
Addressed in 4077b85. The interleaved pipeline helper now returns final weights in addition to outputs, and the reused-buffer test compares graph/eager final weights to cover gradient correctness. Full tests/pytorch/test_cuda_graphs.py passed on H100: 415 passed, 423 skipped.
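The strengthened check can be sketched in shape only (pure-Python stand-ins for the real helper and models; the actual test compares tensors from both schedules):

```python
def run_schedule(with_graph):
    # Stand-in for _test_cuda_graphs_with_interleaved_pipeline_parallelism,
    # now returning (forward outputs, final weights) so gradient bugs surface.
    outputs = [1.0, 2.0]
    weights = [0.5, -1.25]  # a grad-path bug would make graph weights diverge
    return outputs, weights


eager_outputs, eager_weights = run_schedule(with_graph=False)
graph_outputs, graph_weights = run_schedule(with_graph=True)
assert eager_outputs == graph_outputs  # forward correctness
assert eager_weights == graph_weights  # gradient/update correctness
print("outputs and final weights match")
```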
Force-pushed 441e419 to beff9c1
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Force-pushed beff9c1 to 4077b85
Summary

Fix CUDA graph replay so parameter gradients returned from `Graphed.backward` do not expose CUDA graph static buffers to downstream autograd users. The fix clones returned parameter gradients before handing them back to autograd, while preserving the existing aliasing behavior for delayed-wgrad parameters marked with `skip_backward_post_hook`.

Root Cause
When CUDA graph replay returns parameter grad slots directly from static graph buffers, downstream autograd users can retain references to buffers that are overwritten by later graph replays. This can corrupt retained grads or break gradient accumulation semantics.
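The hazard can be reproduced without a GPU by modeling a static buffer that each replay overwrites (a pure-Python stand-in, not actual CUDA graph code):

```python
class StaticBuffer:
    """Stand-in for a CUDA graph static grad buffer reused across replays."""

    def __init__(self):
        self.data = [0.0]


def replay(buffer, grad_value):
    buffer.data[0] = grad_value  # each replay overwrites the same memory
    return buffer.data  # bug if returned directly: the caller holds an alias


buf = StaticBuffer()
retained = replay(buf, 1.0)    # a downstream user retains the "grad" directly
safe = list(replay(buf, 1.0))  # copying (cloning) breaks the alias
replay(buf, 99.0)              # a later replay overwrites the static buffer
print(retained[0], safe[0])    # 99.0 1.0: the retained alias was corrupted
```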
This is related to PyTorch issue pytorch/pytorch#181723.
Changes

- Clone parameter gradients returned from `Graphed.backward` before handing them to autograd.