guard fuser grad checks on non-leaf nodes #2919

CarlosGomes98 wants to merge 1 commit into NVIDIA:main
Conversation
I'm not sure if this is really addressing the root cause of the issue. Two problems:
- We aren't actually protecting against setting `requires_grad` on non-leaf nodes. We're just skipping `requires_grad` logic when `torch.is_grad_enabled() == True`. (A minimal repro of the non-leaf constraint is sketched right after this list.)
- Do we even want to skip setting `requires_grad` on non-leaf nodes? The backward expects grads from each of the outputs, so we need `requires_grad` for autograd to do the right thing.
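For context, here is a minimal standalone PyTorch snippet (not from the PR) illustrating the constraint under discussion: `requires_grad` can only be cleared on leaf tensors, which is why detaching comes up below.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2  # non-leaf: y has a grad_fn

try:
    # Clearing the flag on a non-leaf tensor is rejected by autograd.
    y.requires_grad_(False)
except RuntimeError as err:
    print(err)  # "you can only change requires_grad flags of leaf variables ..."

# Detaching first gives a leaf tensor whose flag can be toggled freely.
y_leaf = y.detach()
y_leaf.requires_grad_(True)
```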
I think the right solution is smarter logic when setting `requires_grad_`. Maybe something like:

```python
x_requires_grad = fuser.first_op_requiring_backward < fuser._num_basic_ops
if x_requires_grad != x.requires_grad:
    x = x.detach()
    if x_requires_grad:
        x.requires_grad_()
    # Or maybe only detach if x is a non-leaf node?
    # Need to check if the CPU overhead of checking
    # is worth saving the CPU overhead of detaching.
...
return x
```

Another approach would be changing our ops to always return leaf nodes. For example, here is the forward pass of `MakeExtraOutput`:
This would be changed to:
```python
out = input_.detach()
return out, [(out,)]
```

The next comment refers to these lines in `fuser.py`:

```python
for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
    for y in ys:
        y.requires_grad_(idx >= fuser.first_op_requiring_backward)
if func_ctx is not None:
```
This logic is not intuitive. `func_ctx` is `None` when `torch.is_grad_enabled() == False`:

`TransformerEngine/transformer_engine/pytorch/ops/fuser.py`, lines 504 to 509 in `0c2e7b0`

It would be better to pass in `is_grad_enabled` as an arg so that we can be explicit and not rely on secret contracts.
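A hedged sketch of that suggestion, assuming a hypothetical helper (`_mark_extra_output_grads` and its signature are illustrative, not the real `fuser.py` code): the grad-enabled state is threaded through as an argument instead of being inferred from `func_ctx is None`.

```python
# Hypothetical helper, not the actual fuser.py implementation.
def _mark_extra_output_grads(basic_op_idxs, fused_op_extra_outputs, fuser, is_grad_enabled):
    if not is_grad_enabled:
        # No autograd node is being recorded for this forward pass, so avoid
        # mutating requires_grad on outputs that may be non-leaf tensors.
        return
    for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs):
        for y in ys:
            y.requires_grad_(idx >= fuser.first_op_requiring_backward)
```

The call site would then pass `is_grad_enabled=torch.is_grad_enabled()` (or whatever value the fuser has already computed) explicitly, making the contract visible in the signature.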
As I understand it, the real issue is that when the `forward_func` is `.apply`, we are free to set `requires_grad_` on returned tensors. But when it is `.forward`, we cannot mutate this state on non-leaf tensors.

When `torch.is_grad_enabled()` is false, we bypass `.apply` and call `.forward` directly with no `func_ctx`. In that path there is no `OperationFuserAutogradFunction` node registered, so no fuser backward will run. So I think this

> The backward expects grads from each of the outputs, so we need `requires_grad` for autograd to do the right thing

is not true, because we cannot run `backward()` through it.

I think it makes sense to pass this in as an explicit argument, as you say, instead of relying on `func_ctx` being `None`. But I think the current logic is correct.
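For readers outside the codebase, a minimal self-contained sketch of the two call paths described above (the class below is illustrative, not TransformerEngine's actual `OperationFuserAutogradFunction`):

```python
import torch

class _ToyFuserFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # When invoked via .apply(), ctx is a real autograd context; a caller
        # that bypasses autograd can call forward() directly with ctx=None.
        return x * 2

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2

x = torch.ones(3, requires_grad=True)

if torch.is_grad_enabled():
    # .apply() records an autograd node, so backward() can run through it.
    y = _ToyFuserFunction.apply(x)
    y.sum().backward()
else:
    # Direct .forward() call: no autograd node is registered, so the custom
    # backward will never run for this op -- the point made in the reply above.
    y = _ToyFuserFunction.forward(None, x)
```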
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: