fix: skip generating fp8 & mxfp8 checkpoints if unsupported #2935
kainzhong wants to merge 3 commits into NVIDIA:main
Conversation
Greptile Summary

This PR extracts the duplicate quantization-availability checks in tests/pytorch/test_checkpoint.py into a single should_skip helper.

Confidence Score: 4/5 — Safe to merge for the pytest use case; the script invocation path has a pre-existing caveat already captured in review history. The refactoring correctly deduplicates the skip logic and works properly within a pytest session. The one substantive concern, about pytest.skip() being called in a non-pytest (CLI main()) context, has already been raised in a prior review thread, so the score reflects a clean P2-only new finding.

Important Files Changed: tests/pytorch/test_checkpoint.py — specifically the interaction between _save_checkpoint and main() when run as a standalone script.
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[_save_checkpoint or test_module] --> B[should_skip name]
    B --> C{dot in name?}
    C -- No --> D[return None]
    C -- Yes --> E{fp8 and not fp8_available?}
    E -- Yes --> F[return reason_for_no_fp8]
    E -- No --> G{mxfp8 and not mxfp8_available?}
    G -- Yes --> H[return reason_for_no_mxfp8]
    G -- No --> D
    D --> I{skip_reason is None?}
    I -- Yes --> J[Proceed with checkpoint]
    I -- No --> K[pytest.skip skip_reason]
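The flowchart above can be sketched as a small helper. This is a minimal illustration, not the PR's actual code: the availability flags and reason strings would normally come from Transformer Engine's runtime capability queries, and are hard-coded booleans here purely for demonstration.

```python
# Hypothetical sketch of the should_skip helper from the flowchart.
# fp8_available / mxfp8_available and the reason strings are assumed
# inputs; in the real test script they would come from hardware checks.
fp8_available, reason_for_no_fp8 = False, "Device does not support FP8"
mxfp8_available, reason_for_no_mxfp8 = False, "Device does not support MXFP8"


def should_skip(name):
    """Return a skip reason if `name` requests an unsupported recipe, else None."""
    if "." not in name:
        return None
    # Assume the recipe is a dot-separated token in the checkpoint name.
    # Token membership avoids "mxfp8" accidentally matching a plain
    # "fp8" substring check.
    tokens = name.split(".")
    if "fp8" in tokens and not fp8_available:
        return reason_for_no_fp8
    if "mxfp8" in tokens and not mxfp8_available:
        return reason_for_no_mxfp8
    return None


print(should_skip("linear.fp8"))  # reason string when FP8 is unavailable
print(should_skip("linear"))      # None: no recipe suffix, nothing to skip
```

Returning the reason (or None) rather than calling pytest.skip() inside the helper keeps the helper usable from both pytest tests and the standalone script path.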
Reviews (4): Last reviewed commit: "Merge branch 'main' into fix/skip_genera..."
skip_reason = should_skip(name)
if skip_reason is not None:
    pytest.skip(skip_reason)
pytest.skip() raises an exception when called outside pytest
pytest.skip() always raises _pytest.outcomes.Skipped (a BaseException subclass), even when called outside a pytest session. When _save_checkpoint is invoked from main() (i.e., python3 test_checkpoint.py --save-checkpoint all), this exception propagates out of _save_checkpoint, exits the for loop immediately, and the script terminates with an uncaught exception — so any modules after the first skipped one are never processed. The PR description says this script invocation was the broken use case being fixed, but the fix doesn't handle it correctly.
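The behavior can be verified with a short standalone snippet run outside any pytest session (the skip reason string here is just an example):

```python
import pytest

# pytest.skip() raises _pytest.outcomes.Skipped even when no pytest session
# is running. Skipped derives from BaseException rather than Exception, so a
# plain `except Exception:` in a script's main() would not catch it.
try:
    pytest.skip("fp8 not supported on this device")
except pytest.skip.Exception as err:  # pytest.skip.Exception is the Skipped type
    caught_msg = str(err)

print(caught_msg)
```

This is why the skip must either be caught explicitly as pytest.skip.Exception or avoided entirely on the script path.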
A minimal repair is to return early (print + return) when the call originates from the script context, or catch pytest.skip.Exception in main():
if args.save_checkpoint == "all":
    for name in _TestLoadCheckpoint_name_list:
        try:
            TestLoadCheckpoint._save_checkpoint(name, checkpoint_dir=checkpoint_dir)
        except pytest.skip.Exception as e:
            print(f"Skipping checkpoint for {name}: {e}")

74bb532 to 8639210
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
8639210 to b498738
I don't agree with this approach. If you specify --save-checkpoint all, you should get all the checkpoints. We don't need to put much effort into making TestLoadCheckpoint._save_checkpoint fully adaptable to all systems, since it's not run in the test suite and is actually a special utility for updating checkpoints in our CI infrastructure.
Description

Using python3 tests/pytorch/test_checkpoint.py --save-checkpoint all to generate checkpoints would fail on hardware that doesn't support fp8 and mxfp8. This PR refactors the test script so that it first checks whether each checkpoint type is supported.

Type of change
Changes

Make test_checkpoint.py check whether the checkpoint to be generated is supported on the current hardware.

Checklist: