Description
Adds EnsembleWrapper to physicsnemo/experimental/models — a model-agnostic utility that wraps any physicsnemo.Module to provide ensemble-based epistemic uncertainty quantification.
Motivation
Deploying AI surrogates in industrial settings requires knowing when to trust the model. The 25.08 release introduced ensemble-based confidence estimation as a workflow in physicsnemo-cfd, scoped to the DoMINO automotive aerodynamics model (now moved to deprecated/). No model-agnostic equivalent exists in the core library. This PR promotes that pattern to a reusable first-class API.
What this adds
EnsembleWrapper — wraps any list of trained physicsnemo.Module instances and provides:
forward() — returns the ensemble mean, making it a drop-in replacement for any single model
predict_with_uncertainty() — returns mean, epistemic std, and all member predictions as an EnsemblePrediction dataclass
from_checkpoints() — constructs the ensemble directly from saved .pt checkpoint files
EnsemblePrediction dataclass — clean container for mean, std, and raw predictions
Example: examples/structural_mechanics/ensemble_uq/ — demonstrates the wrapper on a beam deflection surrogate, showing that epistemic std grows where the model is uncertain
Tests: test/models/test_ensemble_wrapper.py — 18 checks covering construction, forward pass shapes, mathematical correctness, and checkpoint round-trip
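For readers unfamiliar with the technique, the statistics behind forward() and predict_with_uncertainty() can be sketched in a few lines. This is a minimal illustration, not the PR's implementation: the nn.Linear members are stand-ins for trained physicsnemo.Module instances, and the use of the population standard deviation (unbiased=False) is an assumption, since the description does not say which estimator the wrapper uses.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in members: three independently initialised linear models.
# (Assumption: real members would be trained physicsnemo.Module instances.)
members = [nn.Linear(4, 2) for _ in range(3)]
x = torch.randn(8, 4)  # batch of 8 inputs with 4 features each

with torch.no_grad():
    # Stack member outputs along a new leading "member" axis: (M, B, out).
    preds = torch.stack([m(x) for m in members], dim=0)

# Ensemble mean: what a drop-in forward() would return, shape (B, out).
mean = preds.mean(dim=0)

# Spread across members, used as the epistemic uncertainty estimate.
# unbiased=False (population std) is an assumption made for this sketch.
std = preds.std(dim=0, unbiased=False)
```

The spread is zero wherever all members agree and grows where they disagree, which is the behaviour the beam deflection example is described as demonstrating.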
Usage
from physicsnemo.experimental.models.ensemble_wrapper import EnsembleWrapper
# Wrap independently trained members
ensemble = EnsembleWrapper(members)
# Drop-in forward (returns mean)
mean = ensemble(x)
# Uncertainty-aware inference
result = ensemble.predict_with_uncertainty(x)
result.mean # (B, out_features)
result.std # epistemic uncertainty
Checklist
I am familiar with the Contributing Guidelines
New or existing tests cover these changes
Follows MOD-001 (uses physicsnemo.Module base class via FullyConnected)
Follows MOD-002a (placed in physicsnemo/experimental/models/)
Follows MOD-003b (all docstrings use r""" prefix)
Follows EXT-001 (example-only deps declared in requirements.txt)
This PR adds EnsembleWrapper, a model-agnostic utility that wraps any list of physicsnemo.Module instances to provide ensemble-based epistemic uncertainty quantification, along with an EnsemblePrediction dataclass, a beam deflection example, and a test suite.
EnsembleWrapper (physicsnemo/experimental/models/ensemble_wrapper.py): core class with forward (returns mean, drop-in replacement), predict_with_uncertainty (mean + std + raw outputs), and from_checkpoints factory; correctly placed in the experimental tree per MOD-002a, but has a P1 discrepancy between the from_checkpoints docstring (claiming .mdlus support) and the raw state_dict-only implementation, plus several MOD-003/005/006/010 standard violations.
Tests (test/models/test_ensemble_wrapper.py): 18 checks covering shape, math correctness, and checkpoint round-trips, but missing reference-data non-regression tests (.pth files) and the PhysicsNeMo .mdlus checkpoint loading test required by MOD-008b/c.
Example (examples/structural_mechanics/ensemble_uq/): clean, self-contained beam deflection demo with correctly isolated requirements.txt.
physicsnemo/experimental/models/ensemble_wrapper.py
New EnsembleWrapper module; has a P1 bug (misleading .mdlus support claim in from_checkpoints docstring vs raw state_dict-only implementation) and several MOD-00x standard violations (missing Forward/Outputs class sections, no shape validation, no jaxtyping, splatted kwargs).
test/models/test_ensemble_wrapper.py
Good coverage of shapes and math, but missing reference-data non-regression tests (.pth files) and .mdlus checkpoint round-trip test per MOD-008b/c.
examples/structural_mechanics/ensemble_uq/requirements.txt
Example-only dependency file listing matplotlib; correctly isolated from core package per EXT-001.
Comments Outside Diff (6)
physicsnemo/experimental/models/ensemble_wrapper.py, line 555-567
Misleading .mdlus checkpoint support claim
The docstring states checkpoints can be loaded if saved "via PhysicsNeMo's built-in checkpoint utilities," but the implementation calls torch.load(path, weights_only=True) directly — which cannot deserialize a .mdlus archive (a tar/zip format). Passing a PhysicsNeMo-native .mdlus path will raise an error at runtime rather than loading the weights. The docstring should either restrict the accepted format to raw state_dict files or implement proper .mdlus handling via physicsnemo.utils.checkpoint.
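One possible shape for the fix is to decide the loading path from the file suffix before touching the file. The sketch below is only an illustration: classify_checkpoint and the set of raw suffixes are hypothetical and not part of the PR. A path classified as "mdlus" would go through the PhysicsNeMo loader, and one classified as "state_dict" through torch.load(path, weights_only=True).

```python
from pathlib import Path

# Assumption: raw state_dict files carry one of these suffixes.
RAW_SUFFIXES = {".pt", ".pth"}


def classify_checkpoint(path: str) -> str:
    """Hypothetical helper: name the loader a checkpoint path needs."""
    suffix = Path(path).suffix.lower()
    if suffix == ".mdlus":
        # PhysicsNeMo archive: cannot be read as a raw state_dict.
        return "mdlus"
    if suffix in RAW_SUFFIXES:
        # Plain state_dict saved with torch.save.
        return "state_dict"
    raise ValueError(
        f"Unsupported checkpoint format {suffix!r} for {path!r}; "
        "expected .mdlus, .pt, or .pth"
    )


kinds = [classify_checkpoint(p) for p in ["member_0.pt", "member_1.mdlus"]]
```

Failing early with an explicit message keeps an unsupported format from surfacing as an opaque deserialization error.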
physicsnemo/experimental/models/ensemble_wrapper.py, line 403-457
MOD-003c: Missing required Forward and Outputs class docstring sections
Per MOD-003c, the class docstring must contain Parameters, Forward, and Outputs sections. EnsembleWrapper only documents Parameters, Raises, Examples, and See Also. The Forward section should describe the input tensor signature (currently documented only inside the forward method's own docstring), and the Outputs section should describe the return tensor shape/type at the class level.
physicsnemo/experimental/models/ensemble_wrapper.py, line 483-538
MOD-005: Missing tensor shape validation in public methods
Both forward and predict_with_uncertainty accept a tensor x without any shape validation guarded by torch.compiler.is_compiling(). Per MOD-005, all forward and public methods that accept tensor arguments must validate shapes at entry, wrapped in if not torch.compiler.is_compiling():. Without this, mismatched shapes produce cryptic errors deep in the member model graph.
physicsnemo/experimental/models/ensemble_wrapper.py, line 483-503
MOD-006: Missing jaxtyping tensor annotations
Per MOD-006, all tensor arguments in public methods must use jaxtyping annotations (e.g. Float[torch.Tensor, "batch ..."]). Since EnsembleWrapperMeta sets jit=False, TorchScript incompatibility is not a concern here. Both forward and predict_with_uncertainty currently use plain torch.Tensor annotations instead.
physicsnemo/experimental/models/ensemble_wrapper.py, line 544-551
MOD-010: Splatted **model_kwargs in class-level constructor
Per MOD-010, splatted kwargs should be avoided in model constructors because they obscure what parameters are accepted and risk naming collisions. from_checkpoints is effectively a factory constructor; accepting **model_kwargs makes the valid keyword arguments invisible to IDEs and static checkers. Consider accepting an explicit model_config: Optional[Dict[str, Any]] = None parameter instead.
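To make the suggestion concrete, here is a rough sketch of a factory that takes an explicit model_config instead of **model_kwargs. The names build_members and TinyModel are hypothetical and used only for illustration; weight loading is deliberately left out.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence


def build_members(
    model_factory: Callable[..., Any],
    checkpoint_paths: Sequence[str],
    model_config: Optional[Dict[str, Any]] = None,
) -> List[Any]:
    """Hypothetical helper: build one fresh member per checkpoint path."""
    # Copy so the caller's dict is never mutated.
    config = dict(model_config or {})
    return [model_factory(**config) for _ in checkpoint_paths]


class TinyModel:
    """Stand-in for a physicsnemo.Module subclass (illustration only)."""

    def __init__(self, in_features: int = 1, out_features: int = 1):
        self.in_features = in_features
        self.out_features = out_features


members = build_members(
    TinyModel,
    ["m0.pt", "m1.pt"],
    model_config={"in_features": 3, "out_features": 2},
)
```

With this signature the factory's own parameters stay visible to IDEs and static checkers, and the model arguments travel in a single, named dictionary that cannot collide with them.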
test/models/test_ensemble_wrapper.py, line 809-875
MOD-008b/c: Tests don't meet non-regression or checkpoint-loading standards
Two gaps relative to the coding standards:
MOD-008b: Tests verify output shapes and math properties but never compare tensor values against reference data saved in a .pth file. The standard requires loading pre-saved reference outputs and asserting torch.allclose against them so that silent numerical regressions are caught.
MOD-008c: The checkpoint test saves raw state_dict files and exercises EnsembleWrapper.from_checkpoints(), but MOD-008c requires testing the PhysicsNeMo checkpoint path — saving with model.save("*.mdlus") and loading with physicsnemo.Module.from_checkpoint("*.mdlus") — to validate that the full .mdlus serialization round-trip works correctly.
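The reference-data pattern asked for by MOD-008b can be sketched as follows. This is an assumption-laden illustration rather than the repository's test layout: the nn.Linear model stands in for the ensemble, the file name ensemble_reference.pth is hypothetical, and a temporary directory replaces the committed reference file so the snippet runs on its own.

```python
import tempfile
from pathlib import Path

import torch
from torch import nn


def compute_output() -> torch.Tensor:
    """Deterministic forward pass of a stand-in model."""
    torch.manual_seed(0)  # fixed seed so weights are reproducible
    model = nn.Linear(4, 2)
    x = torch.linspace(-1.0, 1.0, 32).reshape(8, 4)
    with torch.no_grad():
        return model(x)


with tempfile.TemporaryDirectory() as tmp:
    ref_path = Path(tmp) / "ensemble_reference.pth"  # hypothetical file name

    # Step 1 (done once, then committed): save the reference output.
    torch.save(compute_output(), ref_path)

    # Step 2 (every test run): recompute and compare against the file.
    reference = torch.load(ref_path, weights_only=True)
    matches = torch.allclose(compute_output(), reference, rtol=1e-5, atol=1e-6)
```

In a real test, step 1 would live in a one-off generation script and step 2 would end with an assertion on the comparison, so that a silent numerical change in the model fails the suite.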