[PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True #2936
[PyTorch] Main_Grad buffer isnt overwritten when overwrite_main_grad=True #2936vthumbe1503 wants to merge 2 commits intoNVIDIA:mainfrom
Conversation
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
Greptile SummaryThis PR fixes a bug in Confidence Score: 4/5Fix is logically correct and minimal; only a naming clarity concern remains. The core fix is sound — the GEMM accumulate flag retains the correct False value when overwriting, and grouped_wgrad now always wraps main_grad. Post-GEMM sections using fc_op._accumulate_into_main_grad are also correct. Score held at 4 due to no regression tests for the fix and the incomplete PR description/checklist. No files require special attention; the single changed file is straightforward. Important Files Changed
|
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: