[PyTorch] Add distributed Muon optimizer#2920
vcherepanov-nv wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR adds a distributed Muon optimizer.

Confidence Score: 5/5. Safe to merge; all new findings are P2 suggestions and the core distributed math is correct. The optimizer's distributed normalization, transpose handling, Nesterov/heavy-ball update, and weight-decay branches are all correct and consistent with the reference implementation in the test. Previously flagged P1s are either fixed (closure/enable_grad) or noted in prior threads. The only new findings are P2: a documentation gap about rank-symmetric gradient availability and incomplete scale-mode test coverage. Neither blocks correctness in the intended tensor-parallel use case.

Important Files Changed: transformer_engine/pytorch/optimizers/muon.py (collective-deadlock documentation); tests/pytorch/distributed/run_muon_optimizer.py (scale_mode coverage gap)
Sequence Diagram

```mermaid
sequenceDiagram
participant Caller
participant MuonOptimizer
participant _orthogonalize
participant _distributed_normalize_p2_
participant newton_schulz
Caller->>MuonOptimizer: step()
loop for each param with grad
MuonOptimizer->>MuonOptimizer: apply weight decay (decoupled or L2)
MuonOptimizer->>MuonOptimizer: momentum_buffer.lerp_(grad, 1-β)
MuonOptimizer->>MuonOptimizer: compute nesterov/non-nesterov update
MuonOptimizer->>_orthogonalize: update, partition_dim, ...
_orthogonalize->>_orthogonalize: clone + optional transpose
_orthogonalize->>_distributed_normalize_p2_: orth_grad
_distributed_normalize_p2_-->>_distributed_normalize_p2_: dist.all_reduce(norm_sq)
_distributed_normalize_p2_->>_orthogonalize: x /= global_norm
_orthogonalize->>newton_schulz: orth_grad, CusolverMpCtx
newton_schulz-->>newton_schulz: distributed NS iterations
newton_schulz->>_orthogonalize: orth_grad (orthogonalized)
_orthogonalize->>_orthogonalize: optional un-transpose + scale
_orthogonalize->>MuonOptimizer: orth_update
MuonOptimizer->>MuonOptimizer: p.add_(orth_update, alpha=-lr)
end
MuonOptimizer->>Caller: loss
```
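Pieced together from the hunks reviewed below, a hypothetical usage sketch; the import path, the `lr` argument, and the shard shapes are assumptions, not taken from the PR:

```python
# Hypothetical usage sketch; launch with torchrun so NCCL env vars are set.
import torch
import torch.distributed as dist

# Assumed import path, based on the file this PR touches.
from transformer_engine.pytorch.optimizers.muon import MuonOptimizer

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

# Local shard of a logical (96, 128) matrix, split along dim 1 across 2 ranks.
shard = torch.nn.Parameter(torch.randn(96, 64, device="cuda"))

opt = MuonOptimizer(
    [shard],
    lr=0.02,                         # illustrative value
    process_group=dist.group.WORLD,  # pass the TP group explicitly
    partition_dim=1,                 # column-parallel sharding
)

loss = (shard**2).sum()
loss.backward()
opt.step()
opt.zero_grad()
```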
Reviews (2). Last reviewed commit: "Fix Muon closure and reference test"
```python
def step(self, closure=None):
    """Perform a single optimization step."""
    loss = None
    if closure is not None:
        loss = closure()
```
**Closure called inside `@torch.no_grad()`, preventing gradient computation**

`closure()` is invoked while `torch.no_grad()` is active. Any `loss.backward()` call inside the closure will silently produce zero/no gradients. The standard PyTorch pattern (used in SGD, Adam, etc.) is to wrap the closure call in `with torch.enable_grad():`.
Suggested change:

```diff
 @torch.no_grad()
 def step(self, closure=None):
     """Perform a single optimization step."""
     loss = None
     if closure is not None:
-        loss = closure()
+        with torch.enable_grad():
+            loss = closure()
```
```python
    scale_mode: str,
    extra_scale_factor: float,
    eps: float,
) -> torch.Tensor:
    global_shape = [grad.size(0), grad.size(1)]
    global_shape[partition_dim] *= world_size
```
**Reference `global_shape` incorrectly scales an already-full tensor**

`_reference_orthogonalize` receives the full matrix (shape `full_shape`) but then multiplies `global_shape[partition_dim]` by `world_size` a second time. For `partition_dim=1` with `world_size=2` and `full_shape=(96, 128)` this gives `global_shape=[96, 256]`, so `get_muon_scale_factor` returns `max(96, 256) ** 0.5 = 16`. The optimizer, operating on the shard `(96, 64)`, correctly reconstructs `global_shape=[96, 128]` and computes `max(96, 128) ** 0.5 ≈ 11.3`. This √2 discrepancy means the reference cannot correctly validate the optimizer's output.

The `global_shape[partition_dim] *= world_size` line should be removed, since the input is already the full matrix.
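The arithmetic, as a standalone sketch (shapes taken from the example above; the spectral scale is assumed to be `max(size_out, size_in) ** 0.5` as described in this finding):

```python
# Illustration of the double-scaling described above.
world_size = 2
partition_dim = 1
full_shape = [96, 128]       # reference already holds the full matrix

# Buggy reference path: scales the already-full dim again -> [96, 256]
buggy_global = list(full_shape)
buggy_global[partition_dim] *= world_size
print(max(buggy_global) ** 0.5)   # 16.0

# Optimizer path: reconstructs the global shape from the shard (96, 64)
shard_shape = [96, 64]
true_global = list(shard_shape)
true_global[partition_dim] *= world_size   # -> [96, 128]
print(max(true_global) ** 0.5)    # ~11.31, hence the sqrt(2) discrepancy
```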
```python
if mode == "unit_rms_norm":
    return (size_out / size_in) ** 0.5
```
**`unit_rms_norm` mode can divide by zero when `size_in == 0`**

`(size_out / size_in) ** 0.5` raises `ZeroDivisionError` when `size_in` is 0. While the optimizer validates that the partition dimension is non-empty, it doesn't ensure the other dimension is non-zero. Consider adding a guard or documenting that both dimensions must be strictly positive.
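One possible guard, as a sketch; the message wording and placement are assumptions:

```python
if mode == "unit_rms_norm":
    if size_in == 0:
        raise ValueError("unit_rms_norm requires size_in > 0")
    return (size_out / size_in) ** 0.5
```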
```python
if group["nesterov"]:
    update = grad.lerp(momentum_buffer, group["momentum"])
else:
    update = momentum_buffer
```
**Non-Nesterov `update` is an alias to `momentum_buffer`, not a copy**

`update = momentum_buffer` holds a reference. If `_orthogonalize` ever modifies its input in-place in a future refactor, the momentum buffer will be silently corrupted. `_orthogonalize` currently clones the input immediately, so this is safe today, but a defensive `.clone()` or a comment would make the intent explicit.
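A sketch of the defensive variant; the comment wording is illustrative:

```python
if group["nesterov"]:
    update = grad.lerp(momentum_buffer, group["momentum"])
else:
    # Defensive copy: _orthogonalize clones its input today, but aliasing
    # the momentum buffer here would corrupt it if that ever changes.
    update = momentum_buffer.clone()
```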
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
skyw left a comment:
I'd advise NOT exposing it in the public API; keep it test-only if that is the purpose. Having an optimizer with most of its code copied invites fragmentation.

Before this, all the optimizers TE provides are more-optimized fused versions. A highly optimized fused Muon built on the same concept could be justified, but it would need more consideration, because it has more dependencies on other parts of the training pipeline than elementwise optimizers do.
```python
    on tensor-parallel parameter shards. The local parameter shard must represent a
    partition of a logical 2D matrix across the provided NCCL process group.

    Args:
```
Q: Does TE use NumPy-style docstrings instead of Google style?
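For context, the two conventions differ mainly in how parameters are listed; a minimal illustration with hypothetical function and parameter names:

```python
def numpy_style(x, eps):
    """NumPy style.

    Parameters
    ----------
    x : torch.Tensor
        Input shard.
    eps : float
        Numerical-stability term.
    """

def google_style(x, eps):
    """Google style.

    Args:
        x (torch.Tensor): Input shard.
        eps (float): Numerical-stability term.
    """
```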
```python
def __init__(
    self,
    params: Iterable[torch.nn.Parameter | dict],
```
Nit: The type here doesn't match PyTorch's internal one. Should be fine for the purpose of this class.
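For comparison, a tighter annotation in the spirit of PyTorch's convention might look like this; the alias name `ParamsLike` is illustrative, and PyTorch's own internal alias varies across versions:

```python
from typing import Any, Dict, Iterable, Union

import torch

# Illustrative alias; PyTorch's internal equivalent differs across versions.
ParamsLike = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
```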
```python
    scale_mode: MuonScaleT = "spectral",
    extra_scale_factor: float = 1.0,
    process_group: Optional[dist.ProcessGroup] = None,
    partition_dim: int = 1,
```
```python
    raise ValueError(f"Invalid weight_decay value: {weight_decay}")
if num_ns_steps < 1:
    raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
if partition_dim not in (0, 1):
```
Q: Does this class intend to support the non-distributed case? partition_dim would be -1 in TE in such a case.
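One way the non-distributed case could be accommodated, purely as a sketch of the idea (not what this PR implements; assumes `torch.distributed as dist` is in scope):

```python
# Sketch: treat partition_dim == -1 (TE's non-distributed convention) as
# "not sharded" and skip collectives entirely.
if partition_dim == -1:
    world_size = 1                  # the local tensor is the full matrix
elif partition_dim in (0, 1):
    world_size = dist.get_world_size(process_group)
else:
    raise ValueError(f"Invalid partition_dim: {partition_dim}")
```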
```python
if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
```
Same question as above regarding single-GPU support.
```python
if process_group is None:
    if not dist.is_initialized():
        raise RuntimeError("MuonOptimizer requires torch.distributed to be initialized.")
    process_group = dist.group.WORLD
```
Suggestion: This silent behavior is dangerous. If the user forgot to pass the correct TP group, the wrong group will be used.
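A defensive alternative, sketched with illustrative message wording:

```python
# Sketch: fail loudly instead of silently defaulting to the global group.
if process_group is None:
    raise ValueError(
        "MuonOptimizer: pass the tensor-parallel process group explicitly; "
        "defaulting to dist.group.WORLD can silently use the wrong group."
    )
```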
```python
    eps: float,
) -> torch.Tensor:
    self._validate_param(grad, partition_dim)
    world_size = dist.get_world_size(self.process_group)
```
Same suggestion as above. The silent behavior of a None process group falling back to the default is dangerous. (Understood that this comes from PyTorch for historical reasons.)
```python
global_shape[partition_dim] *= world_size

orth_grad = grad.clone()
transposed = partition_dim == 0
```
Attn: This follows the common row- and column-wise tensor parallelism used by most LLMs. It would be suboptimal for anything other than that. Add a comment if the assumption is made.
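A sketch of the kind of comment being requested here (wording is illustrative):

```python
# Assumes the standard row-/column-parallel layouts used by most LLMs:
# for partition_dim == 0 (row-parallel) the matrix is transposed before
# Newton-Schulz so the sharded dimension becomes the inner one. Sharding
# schemes other than plain row/column TP may be suboptimal here.
transposed = partition_dim == 0
```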
Description
Add a distributed Muon optimizer, based on `newton_schulz` orthogonalization.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: