Skip to content

fix: skip generating fp8 & mxfp8 checkpoints if unsupported#2935

Open
kainzhong wants to merge 3 commits intoNVIDIA:mainfrom
kainzhong:fix/skip_generating_unsupported_ckpt
Open

fix: skip generating fp8 & mxfp8 checkpoints if unsupported#2935
kainzhong wants to merge 3 commits intoNVIDIA:mainfrom
kainzhong:fix/skip_generating_unsupported_ckpt

Conversation

@kainzhong
Copy link
Copy Markdown
Collaborator

Description

Using python3 tests/pytorch/test_checkpoint.py --save-checkpoint all to generate checkpoints would fail on some hardwares that don't support fp8 and mxfp8. This PR refactors the test script so it checks if the checkpoint type is supported first.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Make test_checkpoint.pycheck if the checkpoint to be generated is supported on the current hardware.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR extracts duplicate quantization-availability checks from test_module and _save_checkpoint into a shared should_skip() helper, allowing _save_checkpoint to skip unsupported fp8/mxfp8 checkpoint generation gracefully. The refactoring is clean and the logic is correct within the pytest session context.

Confidence Score: 4/5

Safe to merge for the pytest use case; the script invocation path has a pre-existing caveat already captured in review history.

The refactoring correctly deduplicates the skip logic and works properly within a pytest session. The one substantive concern about pytest.skip() being called in a non-pytest (CLI main()) context has already been raised in a prior review thread, so the score reflects a clean P2-only new finding.

tests/pytorch/test_checkpoint.py — specifically the interaction between _save_checkpoint and main() when run as a standalone script.

Important Files Changed

Filename Overview
tests/pytorch/test_checkpoint.py Refactors duplicate skip-logic into a shared should_skip() helper and moves the check into _save_checkpoint; the pytest context and type-annotation nit are the remaining concerns.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[_save_checkpoint or test_module] --> B[should_skip name]
    B --> C{dot in name?}
    C -- No --> D[return None]
    C -- Yes --> E{fp8 and not fp8_available?}
    E -- Yes --> F[return reason_for_no_fp8]
    E -- No --> G{mxfp8 and not mxfp8_available?}
    G -- Yes --> H[return reason_for_no_mxfp8]
    G -- No --> D
    D --> I{skip_reason is None?}
    I -- Yes --> J[Proceed with checkpoint]
    I -- No --> K[pytest.skip skip_reason]
Loading

Reviews (4): Last reviewed commit: "Merge branch 'main' into fix/skip_genera..." | Re-trigger Greptile

Comment on lines +113 to +115
skip_reason = should_skip(name)
if skip_reason is not None:
pytest.skip(skip_reason)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 pytest.skip() raises an exception when called outside pytest

pytest.skip() always raises _pytest.outcomes.Skipped (a BaseException subclass), even when called outside a pytest session. When _save_checkpoint is invoked from main() (i.e., python3 test_checkpoint.py --save-checkpoint all), this exception propagates out of _save_checkpoint, exits the for loop immediately, and the script terminates with an uncaught exception — so any modules after the first skipped one are never processed. The PR description says this script invocation was the broken use case being fixed, but the fix doesn't handle it correctly.

A minimal repair is to return early (print + return) when the call originates from the script context, or catch pytest.skip.Exception in main():

if args.save_checkpoint == "all":
    for name in _TestLoadCheckpoint_name_list:
        try:
            TestLoadCheckpoint._save_checkpoint(name, checkpoint_dir=checkpoint_dir)
        except pytest.skip.Exception as e:
            print(f"Skipping checkpoint for {name}: {e}")

@kainzhong kainzhong force-pushed the fix/skip_generating_unsupported_ckpt branch from 74bb532 to 8639210 Compare April 27, 2026 22:52
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@kainzhong kainzhong force-pushed the fix/skip_generating_unsupported_ckpt branch from 8639210 to b498738 Compare April 27, 2026 22:55
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't agree with this approach. If you specify --save-checkpoint all, you should get all the checkpoints. We don't need to put much effort into making TestLoadCheckpoint._save_checkpoint fully adaptable to all systems, since it's not run in the test suite and is actually a special utility for updating checkpoints in our CI infrastructure.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants