# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Monkey-patch for accelerate's fsdp2_load_full_state_dict buffer handling.

Applicability (scope of this patch)
-----------------------------------
This is **not** needed for FSDP2 in general. It is required only for the narrow
combination of: **FSDP2 configured via an accelerate YAML config** (not torch-native
``ParallelismConfig``) **with** ``cpu_ram_efficient_loading=True``. Today that path is
forced by **models that require transformers 4.57.x** (their ``trust_remote_code`` code
predates transformers 5.x ``ParallelismConfig`` support) **and** are too large to load on
every rank; currently that is only **MiniMax-M2.7** (229B MoE). Models that run on
transformers 5.x (Qwen, Llama, Nemotron, ...) use native ``ParallelismConfig``
(``dp_shard_size > 1``), which handles buffers/dtypes correctly and never enters
``fsdp2_load_full_state_dict``, so they need none of this. Gated off by default; activate
with ``PATCH_FSDP2_BUFFERS_TF457=1``.

Problem
-------
accelerate's ``fsdp2_load_full_state_dict`` (called during model preparation
when ``cpu_ram_efficient_loading=True``) iterates ``model.state_dict()`` and
unconditionally accesses ``.device_mesh`` on every entry, assuming they are all
DTensors. After ``fully_shard()``, **parameters** become DTensors but
**persistent buffers** (e.g., rotary-embedding ``inv_freq``) remain plain
``torch.Tensor``. This crashes with::

    AttributeError: 'Tensor' object has no attribute 'device_mesh'

Additionally, ``cpu_ram_efficient_loading`` causes a dtype divergence: rank 0
loads the model on CPU (inheriting the checkpoint's dtype, e.g. bfloat16) while
other ranks use ``meta`` device (defaulting to float32 for newly-added modules
like the DFlash head). After ``fully_shard()``, the DTensor dtypes differ
across ranks for these modules. Since ``dist.broadcast()`` requires matching
dtypes and element sizes on all ranks, broadcasting a bfloat16 tensor (2
bytes/elem) to a float32 receive buffer (4 bytes/elem) causes an NCCL deadlock.

Why we need FSDP2 via accelerate config (not ParallelismConfig)
---------------------------------------------------------------
MiniMax-M2.7's ``trust_remote_code`` model code requires **transformers 4.57.x**.
Transformers' native FSDP2 support via ``ParallelismConfig`` requires
**transformers 5.x**. So we fall back to configuring FSDP2 through an
``accelerate`` YAML config file (``accelerate_fsdp2.yaml``), which works with
transformers 4.57.x. We set ``dp_shard_size=1`` to prevent ``main.py`` from
creating a ``ParallelismConfig``, letting the accelerate config handle sharding.

Why we need cpu_ram_efficient_loading
-------------------------------------
MiniMax-M2.7 is a 229B MoE model (~230 GB in FP8). Each GB200 node has 4 GPUs
and ~800 GB system RAM. Without ``cpu_ram_efficient_loading``, all 4 ranks per
node load the model to CPU simultaneously (4 × 230 GB ≈ 920 GB), exceeding
system RAM and triggering OOM kills. With ``cpu_ram_efficient_loading``, only
rank 0 loads the model; other ranks initialize on ``meta`` device. The weights
are then broadcast via ``fsdp2_load_full_state_dict``, which is where the bug
hits.

What this patch does
--------------------
1. Before accessing ``.device_mesh``, checks ``isinstance(entry, DTensor)``.
   For non-DTensor entries (persistent buffers), broadcasts the raw tensor from
   rank 0 without calling ``distribute_tensor()``.
2. All ranks iterate ``model.state_dict()`` (post-shard) in the same order so
   broadcast calls match 1-to-1. Rank 0 looks up the full parameter **by key
   name** from the pre-shard state dict, never by positional ``zip``, because
   ``model.to("meta")`` + ``fully_shard()`` can reorder keys.
3. **Dtype synchronization**: rank 0 broadcasts a dtype code for each entry
   before the main loop. All ranks then use the same dtype for their broadcast
   tensors. This fixes the dtype divergence caused by rank 0 loading in
   bfloat16 while other ranks default to float32 for newly-added modules.

Accelerate config constraints (for reference)
----------------------------------------------
``accelerate_fsdp2.yaml`` also requires:

- ``fsdp_use_orig_params: true``: without this, FSDP flattens all params into
  FlatParameter, losing per-parameter ``requires_grad`` flags. The DFlash head
  can't train because its gradients mix with frozen base model zeros.
- ``fsdp_transformer_layer_cls_to_wrap: MiniMaxM2DecoderLayer,DFlashModule``:
  DFlash head params at the model root must be in the wrap policy so they become
  DTensors. Without this, ``fsdp2_load_full_state_dict`` also crashes.
- ``fsdp_sync_module_states: true``: accelerate's launch validator requires it
  when ``cpu_ram_efficient_loading`` is enabled, even though FSDP2 ignores it at
  runtime (sets it to None with a warning).

Does NOT affect models on transformers 5.x
-------------------------------------------
This entire workaround exists ONLY because MiniMax-M2.7 requires
transformers 4.57.x. Models that support transformers 5.x (Qwen, Llama,
Nemotron, etc.) use ``ParallelismConfig`` natively by setting
``dp_shard_size > 1`` in the training args. That code path handles buffers
correctly and does not go through ``fsdp2_load_full_state_dict`` at all.
No accelerate config file, no ``PATCH_FSDP2_BUFFERS_TF457``, no
``OVERRIDE_TRANSFORMERS`` needed.

When to remove
--------------
This patch can be removed when EITHER of these happens:

1. MiniMax updates ``trust_remote_code`` for transformers 5.x, allowing native
   ``ParallelismConfig`` (which handles this correctly).
2. Upstream accelerate fixes ``fsdp2_load_full_state_dict`` to skip non-DTensor
   entries. Track: https://github.com/huggingface/accelerate

Activation
----------
Set ``PATCH_FSDP2_BUFFERS_TF457=1`` in the environment to activate. Off by default.
Only needed in MiniMax-M2.7 pipeline YAMLs.
"""
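import torch

# Dtype encoding for the broadcast dtype-sync step.
_DTYPE_TO_CODE = {
    torch.float32: 0,
    torch.bfloat16: 1,
    torch.float16: 2,
}
# float8_e4m3fn is registered conditionally. Using torch.float8_e4m3fn as a key in
# the literal above would raise AttributeError at import time on torch builds that
# lack it, so a hasattr() check on the value alone cannot protect the key.
if hasattr(torch, "float8_e4m3fn"):
    _DTYPE_TO_CODE[torch.float8_e4m3fn] = 3
_CODE_TO_DTYPE = {v: k for k, v in _DTYPE_TO_CODE.items() if v >= 0}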
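

# Illustrative sketch (not part of the patch, never called by it): the dtype-sync
# round trip in plain Python. Rank 0 encodes each entry's dtype as a small integer,
# the integers travel in a single int32 broadcast, and every rank decodes them with
# the inverse table. Dtype names are plain strings here and the helper name is
# hypothetical, so this runs without torch or a process group.
def _example_dtype_code_roundtrip(dtype_names):
    to_code = {"float32": 0, "bfloat16": 1, "float16": 2}
    to_name = {code: name for name, code in to_code.items()}
    codes = [to_code[name] for name in dtype_names]  # what rank 0 would broadcast
    decoded = [to_name[code] for code in codes]  # what every rank reconstructs
    return codes, decoded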
def apply():
    """Replace accelerate's fsdp2_load_full_state_dict with a buffer-aware version.

    The patch is installed whenever the imports below succeed; any failure is
    reported and the upstream function is left untouched.
    """
    try:
        import accelerate.utils.fsdp_utils as fsdp_utils
        from torch.distributed.tensor import DTensor

        def _dtype_code(dt):
            """Map a dtype to its broadcast sync code, raising on anything unmapped.

            Silently coercing an unknown dtype to fp32 would cast data on the wire (or
            make NCCL refuse on an element-size mismatch), so fail loudly instead.
            """
            code = _DTYPE_TO_CODE.get(dt)
            if code is None or code < 0:
                raise ValueError(
                    f"fsdp2_buffer_patch: unsupported dtype {dt} in the broadcast "
                    f"dtype-sync; add it to _DTYPE_TO_CODE."
                )
            return code

        def _patched(accelerator, model, full_sd, cpu_offload=False):
            import time

            import torch.distributed as dist
            from torch.distributed.tensor import distribute_tensor

            meta_sharded_sd = model.state_dict()
            sharded_sd = {}

            n_total = len(meta_sharded_sd)
            n_dtensor = sum(1 for v in meta_sharded_sd.values() if isinstance(v, DTensor))
            n_buffer = n_total - n_dtensor
            if accelerator.is_main_process:
                print(
                    f"[fsdp2_buffer_patch] State dict: {n_total} entries "
                    f"({n_dtensor} DTensor, {n_buffer} buffer), full_sd: {len(full_sd)}"
                )
            else:
                print(
                    f"[fsdp2_buffer_patch] State dict: {n_total} entries "
                    f"({n_dtensor} DTensor, {n_buffer} buffer)"
                )
            t0 = time.time()

            # --- Step 0: broadcast dtype codes from rank 0 ---
            # cpu_ram_efficient_loading causes rank 0 to load in bfloat16 while
            # other ranks default to float32 for newly-added modules (DFlash).
            # After fully_shard(), DTensor dtypes diverge across ranks.
            # Broadcast rank 0's dtypes so all ranks use the same dtype for
            # each broadcast tensor.
            if accelerator.is_main_process:
                dtype_codes = torch.tensor(
                    [_dtype_code(full_sd[name].dtype) for name in meta_sharded_sd],
                    dtype=torch.int32,
                    device=accelerator.device,
                )
            else:
                dtype_codes = torch.empty(
                    n_total,
                    dtype=torch.int32,
                    device=accelerator.device,
                )
            dist.broadcast(dtype_codes, src=0, group=dist.group.WORLD)
            # tolist() copies the codes to the host once instead of syncing per element.
            broadcast_dtypes = [_CODE_TO_DTYPE[c] for c in dtype_codes.tolist()]
            del dtype_codes

            # Infer dtype/contiguity for cast (copied from upstream).
            def _infer_parameter_dtype(mdl, param_name, empty_param):
                try:
                    old_param = mdl.get_parameter_or_buffer(param_name)
                except AttributeError:
                    base, local = param_name.rsplit(".", 1)
                    old_param = getattr(mdl.get_submodule(base), local)
                is_f8 = hasattr(torch, "float8_e4m3fn") and empty_param.dtype == torch.float8_e4m3fn
                casting_dtype = (
                    old_param.dtype if (empty_param.dtype.is_floating_point and not is_f8) else None
                )
                return old_param is not None and old_param.is_contiguous(), casting_dtype

            def _finish(st, contig, dtype, offload):
                if dtype is not None:
                    st = st.to(dtype=dtype)
                if contig:
                    st = st.contiguous()
                if offload:
                    st = st.to("cpu")
                return st

            # --- Step 1: broadcast all entries ---
            # All ranks iterate meta_sharded_sd in the same order to ensure
            # identical broadcast sequences. Rank 0 looks up the full parameter
            # by name, never by positional zip (model.to("meta") + fully_shard()
            # can reorder keys). broadcast_dtypes[idx] is used for the tensor
            # dtype on ALL ranks to prevent NCCL deadlocks from dtype divergence.
            for idx, (param_name, sharded_param) in enumerate(meta_sharded_sd.items()):
                is_dtensor = isinstance(sharded_param, DTensor)
                bcast_dtype = broadcast_dtypes[idx]

                if not is_dtensor:
                    # Persistent buffer: broadcast raw, no distribute_tensor.
                    if accelerator.is_main_process:
                        t = (
                            full_sd[param_name]
                            .detach()
                            .to(device=accelerator.device, dtype=bcast_dtype)
                        )
                    else:
                        t = torch.empty(
                            sharded_param.size(),
                            device=accelerator.device,
                            dtype=bcast_dtype,
                        )
                    dist.broadcast(t, src=0, group=dist.group.WORLD)
                    sharded_sd[param_name] = t
                    continue

                device_mesh = sharded_param.device_mesh
                if accelerator.is_main_process:
                    ft = (
                        full_sd[param_name]
                        .detach()
                        .to(device=device_mesh.device_type, dtype=bcast_dtype)
                    )
                    if isinstance(ft, DTensor):
                        ft = ft.to_local()
                else:
                    ft = torch.empty(
                        sharded_param.size(),
                        device=device_mesh.device_type,
                        dtype=bcast_dtype,
                    )
                dist.broadcast(ft, src=0, group=dist.group.WORLD)
                st = distribute_tensor(ft, device_mesh, sharded_param.placements)
                contig, _ = _infer_parameter_dtype(model, param_name, ft)
                # Use bcast_dtype (from rank 0) instead of the model's local
                # param dtype. With cpu_ram_efficient_loading, non-rank-0
                # processes have fp32 meta-device params for DFlash, and
                # _infer_parameter_dtype would incorrectly cast bf16 back to fp32.
                is_f8 = hasattr(torch, "float8_e4m3fn") and bcast_dtype == torch.float8_e4m3fn
                final_dtype = None if is_f8 else bcast_dtype
                sharded_sd[param_name] = _finish(st, contig, final_dtype, cpu_offload)

            elapsed = time.time() - t0
            print(
                f"[fsdp2_buffer_patch] Broadcast done in {elapsed:.1f}s, "
                f"loading {len(sharded_sd)} entries into model..."
            )
            model.load_state_dict(sharded_sd, assign=True)
            print(
                f"[fsdp2_buffer_patch] State dict loaded successfully "
                f"({time.time() - t0:.1f}s total)"
            )
            return model

        fsdp_utils.fsdp2_load_full_state_dict = _patched
        print("[fsdp2_buffer_patch] Patched fsdp2_load_full_state_dict for buffer compatibility")
    except Exception as e:
        print(f"[fsdp2_buffer_patch] Patch skipped: {e}")
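

# Illustrative sketch (not part of the patch, never called by it): why rank 0 looks
# up full tensors by key name instead of zipping two state dicts positionally.
# The helper name and the toy keys are hypothetical; plain strings stand in for
# tensors so the example runs without torch or a process group.
def _example_lookup_by_key():
    # Pre-shard state dict as loaded on rank 0.
    full_sd = {"embed.weight": "E", "rotary.inv_freq": "F", "head.weight": "H"}
    # Post-shard key order: same keys, but assume the order changed.
    sharded_keys = ["rotary.inv_freq", "embed.weight", "head.weight"]

    # Positional pairing silently attaches the wrong payload to a key.
    by_position = dict(zip(sharded_keys, full_sd.values()))
    # Name-based lookup is independent of either dict's ordering.
    by_name = {key: full_sd[key] for key in sharded_keys}
    return by_position, by_name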
_clip_grad_norm_call_count = 0


def _clip_grad_norm(parameters, max_norm, norm_type=2):
    """Clip gradient norms for FSDP2 DTensor parameters.

    Bypasses DTensor dispatch (which deadlocks with partially-frozen models
    on the accelerate FSDP2 path) by extracting local tensor shards and
    doing an explicit all_reduce for the global norm.

    Handles Shard (need all_reduce) and Replicate/regular (already global)
    placements. Safe for DFlash-only and LoRA co-training.

    Only finite ``norm_type`` values are supported: the global norm is built by
    summing p-th powers across ranks, which is not valid for the infinity norm.
    """
    global _clip_grad_norm_call_count
    import math

    import torch.distributed as dist
    from torch.distributed.tensor import DTensor
    from torch.distributed.tensor.placement_types import Shard

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(parameters)  # materialize generator
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    # norm_type is identical on every rank, so raising here is symmetric and
    # cannot strand other ranks inside the collective below.
    if math.isinf(norm_type):
        raise ValueError(
            "fsdp2_buffer_patch: _clip_grad_norm supports finite norm_type only; "
            "the infinity norm cannot be accumulated as a sum of p-th powers."
        )

    grads = [p.grad for p in parameters if p.grad is not None]
    # Every rank MUST reach the all_reduce below: under FSDP2 sharding (and especially
    # MoE + LoRA co-training) a rank can legitimately have no grads on a step, e.g. an
    # expert that received no tokens, so the shard it owns gets no gradient. Returning
    # early here while other ranks call all_reduce would deadlock the job. So we
    # never short-circuit before the collective; an empty-grad rank simply contributes a
    # zero local norm and clips nothing.
    if grads:
        device = grads[0]._local_tensor.device if isinstance(grads[0], DTensor) else grads[0].device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Shard DTensors hold partial data and need all_reduce for the global norm.
    # Replicate DTensors and regular tensors already hold full data.
    sharded_norm_p = torch.tensor(0.0, device=device)
    local_norm_p = torch.tensor(0.0, device=device)
    n_sharded = 0
    n_replicate = 0
    n_regular = 0

    for g in grads:
        if isinstance(g, DTensor):
            is_sharded = any(isinstance(p, Shard) for p in g.placements)
            t = g._local_tensor.detach().to(torch.float32)
            n = torch.linalg.vector_norm(t, norm_type)
            if is_sharded:
                sharded_norm_p += n.pow(norm_type)
                n_sharded += 1
            else:
                local_norm_p += n.pow(norm_type)
                n_replicate += 1
        else:
            n = torch.linalg.vector_norm(g.detach().to(torch.float32), norm_type)
            local_norm_p += n.pow(norm_type)
            n_regular += 1

    # Symmetric across ranks: reached on every rank regardless of whether this rank had
    # grads (see note above). Guard for the non-distributed case where local == global.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(sharded_norm_p, op=dist.ReduceOp.SUM)

    total_norm = (sharded_norm_p + local_norm_p).pow(1.0 / norm_type)
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)

    # Debug: log computation breakdown on first 5 calls (no collectives, so safe).
    _clip_grad_norm_call_count += 1
    _rank0 = not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0
    if _clip_grad_norm_call_count <= 5 and _rank0:
        print(
            f"[clip_grad_norm debug] call={_clip_grad_norm_call_count} "
            f"total_norm={total_norm.item():.6f} "
            f"sharded_norm_p={sharded_norm_p.item():.6f} local_norm_p={local_norm_p.item():.6f} "
            f"grads={len(grads)} (sharded={n_sharded} replicate={n_replicate} regular={n_regular}) "
            f"max_norm={max_norm} clip_coef={clip_coef.item():.6f}"
        )

    for g in grads:
        if isinstance(g, DTensor):
            g._local_tensor.mul_(clip_coef)
        else:
            g.mul_(clip_coef)

    return total_norm
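

# Illustrative sketch (not part of the patch, never called by it): the norm
# arithmetic used by _clip_grad_norm, in plain Python. Each simulated rank holds one
# shard of a gradient and computes sum(|x|^p) locally; adding those partial sums
# (the role all_reduce with SUM plays across ranks) and taking the p-th root gives
# the norm of the full gradient. The helper name and sample values are hypothetical.
def _example_global_norm_from_shards(shards, norm_type=2.0):
    # Per-rank partial result: the local norm raised to the p-th power.
    local_norm_p = [sum(abs(x) ** norm_type for x in shard) for shard in shards]
    # Stand-in for dist.all_reduce(..., op=dist.ReduceOp.SUM) over all ranks.
    summed = sum(local_norm_p)
    return summed ** (1.0 / norm_type)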
def patch_accelerator(accelerator):
    """Replace accelerator's clip_grad_norm_ with FSDP2-safe version."""
    accelerator.clip_grad_norm_ = _clip_grad_norm
    print(
        "[fsdp2_buffer_patch] Patched accelerator.clip_grad_norm_ for FSDP2 DTensor compatibility"
    )


def log_param_dtypes(model):
    """Debug aid: log per-rank parameter dtype counts (gated by DFLASH_LOG_PARAM_DTYPES=1).

    Used to verify the FSDP2 dtype synchronization above: after ``fully_shard()`` params
    are DTensors whose dtype lives on ``_local_tensor``. Off by default; this is purely
    diagnostic and has no effect on training.
    """
    import os

    if os.environ.get("DFLASH_LOG_PARAM_DTYPES") != "1":
        return
    rank = int(os.environ.get("RANK", 0))
    dtypes = {}
    for name, p in model.named_parameters():
        dt_key = str(p.dtype) if not hasattr(p, "_local_tensor") else str(p._local_tensor.dtype)
        dtypes.setdefault(dt_key, []).append(name)
    for dt, names in dtypes.items():
        print(f"[dtype_check rank={rank}] {dt}: {len(names)} params (e.g. {names[0]})")