[PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True #2936

Open
vthumbe1503 wants to merge 2 commits into NVIDIA:main from vthumbe1503:delay_wgrad_bug

@vthumbe1503
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 added the bug (Something isn't working) and 2.15.0 labels Apr 28, 2026
@vthumbe1503
Collaborator Author

/te-ci pytorch

@vthumbe1503 vthumbe1503 changed the title from "Main_Grad buffer isnt overwritten when overwrite_main_grad=True" to "[Pytorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True" Apr 28, 2026
@vthumbe1503 vthumbe1503 changed the title from "[Pytorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True" to "[PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True" Apr 28, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR fixes a bug in _compute_grad_params where the single_grouped_weight path silently fell through to a fresh buffer allocation instead of writing to main_grad when overwrite_main_grad=True. By removing the if accumulate_into_main_grad: guard, grouped_wgrad is now always backed by main_grad when fc_op._accumulate_into_main_grad is active, and the GEMM correctly receives accumulate=False (via the unchanged local variable) so the buffer is overwritten rather than accumulated into.
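The behaviour described above can be illustrated with a small pure-Python sketch. It is not the code in backward_grouped_mlp.py: toy_wgrad_gemm, compute_wgrad, and the fixed switch are hypothetical names, plain lists stand in for tensors, and the way the local GEMM flag is derived (accumulate only when not overwriting) is an assumption inferred from the summary.

```python
def toy_wgrad_gemm(x, dy, out, accumulate):
    """Toy stand-in for the wgrad GEMM: writes x[i] * dy[i] into out."""
    for i in range(len(out)):
        prod = x[i] * dy[i]
        # accumulate=True adds to the existing value, False overwrites it.
        out[i] = out[i] + prod if accumulate else prod
    return out


def compute_wgrad(main_grad, x, dy, accumulate_into_main_grad,
                  overwrite_main_grad, fixed):
    # Assumed derivation of the local flag handed to the GEMM.
    gemm_accumulate = accumulate_into_main_grad and not overwrite_main_grad
    if fixed:
        # After the fix: main_grad backs the output whenever the op-level
        # flag is set, even if the GEMM itself must not accumulate.
        use_main_grad = accumulate_into_main_grad
    else:
        # Before the fix: the guard used the local flag, which is False
        # when overwrite_main_grad=True, so a fresh buffer was allocated.
        use_main_grad = gemm_accumulate
    out = main_grad if use_main_grad else [0.0] * len(main_grad)
    return toy_wgrad_gemm(x, dy, out, gemm_accumulate)


x, dy = [1.0, 2.0], [3.0, 4.0]

stale = [10.0, 10.0]
compute_wgrad(stale, x, dy, True, True, fixed=False)
# stale is still [10.0, 10.0]: the result went to a throwaway buffer.

fresh = [10.0, 10.0]
compute_wgrad(fresh, x, dy, True, True, fixed=True)
# fresh is now [3.0, 8.0]: main_grad was overwritten, not accumulated into.
```

With overwrite_main_grad=False both variants behave the same and accumulate into main_grad, which is consistent with the bug only surfacing in the overwrite case.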

Confidence Score: 4/5

Fix is logically correct and minimal; only a naming clarity concern remains.

The core fix is sound — the GEMM accumulate flag retains the correct False value when overwriting, and grouped_wgrad now always wraps main_grad. Post-GEMM sections using fc_op._accumulate_into_main_grad are also correct. Score held at 4 due to no regression tests for the fix and the incomplete PR description/checklist.

No files require special attention; the single changed file is straightforward.

Important Files Changed

Filename: transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Overview: Bug fix: removes the conditional guard in the single_grouped_weight path so grouped_wgrad is always backed by main_grad when _accumulate_into_main_grad is set, regardless of overwrite_main_grad; post-GEMM sections consistently switched to fc_op._accumulate_into_main_grad. No tests added.

Comments Outside Diff (1)

  1. transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py, line 175-225 (link)

    P2 Stale local variable still driving GEMM accumulate flag

    accumulate_into_main_grad (local, lines 175/201) is still the value threaded into gemm_fn via accumulate=accumulate_into_main_grad (lines 216/225). The post-GEMM sections have been switched to fc_op._accumulate_into_main_grad, but the GEMM kernel itself has not. This is actually correct — the GEMM must use False when overwrite_main_grad=True so it overwrites rather than accumulates — but now the local variable is serving a narrower purpose than its name implies ("should GEMM accumulate?" vs. "are we writing into main_grad at all?"). A rename to something like gemm_accumulate would make the intent unambiguous and prevent future readers from assuming the two variables are always in sync.

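    One way to picture the suggested rename is to derive the two meanings as separately named values. The sketch below is illustrative only: WgradFlags, derive_wgrad_flags, and writes_to_main_grad are hypothetical names, gemm_accumulate is the name proposed in the comment, and the derivation assumes the GEMM accumulates only when main_grad is in use and is not being overwritten.

```python
from typing import NamedTuple


class WgradFlags(NamedTuple):
    writes_to_main_grad: bool  # "are we writing into main_grad at all?"
    gemm_accumulate: bool      # "should the GEMM accumulate?"


def derive_wgrad_flags(accumulate_into_main_grad, overwrite_main_grad):
    """Split the op-level setting into the two questions the code asks."""
    writes = bool(accumulate_into_main_grad)
    # The GEMM may only accumulate when main_grad is the output buffer
    # and the caller has not asked for it to be overwritten.
    gemm_acc = writes and not overwrite_main_grad
    return WgradFlags(writes_to_main_grad=writes, gemm_accumulate=gemm_acc)


# Print the full truth table for the two inputs.
for acc in (False, True):
    for overwrite in (False, True):
        print(acc, overwrite, derive_wgrad_flags(acc, overwrite))
```

    The two flags differ only in the (True, True) row, which is the overwrite_main_grad=True case this PR addresses; distinct names make that divergence visible at the call sites.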

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
