
Add EnsembleWrapper for model-agnostic uncertainty quantification#1644

Open
theosig wants to merge 2 commits into NVIDIA:main from theosig:feature/ensemble-wrapper

Conversation

@theosig

@theosig theosig commented May 13, 2026

Description
Adds EnsembleWrapper to physicsnemo/experimental/models — a model-agnostic utility that wraps any physicsnemo.Module to provide ensemble-based epistemic uncertainty quantification.

Motivation
Deploying AI surrogates in industrial settings requires knowing when to trust the model. The 25.08 release introduced ensemble-based confidence estimation as a workflow in physicsnemo-cfd, scoped to the DoMINO automotive aerodynamics model (now moved to deprecated/). No model-agnostic equivalent exists in the core library. This PR promotes that pattern to a reusable first-class API.

What this adds

EnsembleWrapper — wraps any list of trained physicsnemo.Module instances and provides:

forward() — returns the ensemble mean, making it a drop-in replacement for any single model
predict_with_uncertainty() — returns mean, epistemic std, and all member predictions as an EnsemblePrediction dataclass
from_checkpoints() — constructs the ensemble directly from saved .pt checkpoint files

EnsemblePrediction dataclass — clean container for mean, std, and raw predictions
Example: examples/structural_mechanics/ensemble_uq/ — demonstrates the wrapper on a beam deflection surrogate, showing that epistemic std grows where the model is uncertain
Tests: test/models/test_ensemble_wrapper.py — 18 checks covering construction, forward pass shapes, mathematical correctness, and checkpoint round-trip
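
The mean/std computation behind `predict_with_uncertainty()` can be sketched with a hypothetical `EnsemblePrediction` container; the field names follow the PR description, but the exact signatures are assumptions, not the merged implementation:

```python
from dataclasses import dataclass

import torch


@dataclass
class EnsemblePrediction:
    """Hypothetical container mirroring the fields named in this PR."""

    mean: torch.Tensor         # (B, out_features) ensemble mean
    std: torch.Tensor          # (B, out_features) epistemic uncertainty
    predictions: torch.Tensor  # (n_members, B, out_features) raw member outputs


def ensemble_stats(member_outputs: list[torch.Tensor]) -> EnsemblePrediction:
    # Stack member outputs along a new leading "member" dimension,
    # then reduce over it for the ensemble statistics.
    stacked = torch.stack(member_outputs, dim=0)
    return EnsemblePrediction(
        mean=stacked.mean(dim=0),
        std=stacked.std(dim=0),
        predictions=stacked,
    )
```

With N members producing `(B, out_features)` outputs, the stack is `(N, B, out_features)` and both reductions collapse the member axis, which is what makes `forward()` a drop-in single-model replacement.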

Usage

from physicsnemo.experimental.models.ensemble_wrapper import EnsembleWrapper

# Wrap independently trained members
ensemble = EnsembleWrapper(members)

# Drop-in forward (returns mean)
mean = ensemble(x)

# Uncertainty-aware inference
result = ensemble.predict_with_uncertainty(x)
result.mean   # (B, out_features)
result.std    # epistemic uncertainty

Checklist

  • I am familiar with the Contributing Guidelines
  • New or existing tests cover these changes
  • Follows MOD-001 (uses physicsnemo.Module base class via FullyConnected)
  • Follows MOD-002a (placed in physicsnemo/experimental/models/)
  • Follows MOD-003b (all docstrings use r""" prefix)
  • Follows EXT-001 (example-only deps declared in requirements.txt)

@copy-pr-bot

copy-pr-bot Bot commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR adds EnsembleWrapper, a model-agnostic utility that wraps any list of physicsnemo.Module instances to provide ensemble-based epistemic uncertainty quantification, along with an EnsemblePrediction dataclass, a beam deflection example, and a test suite.

  • EnsembleWrapper (physicsnemo/experimental/models/ensemble_wrapper.py): core class with forward (returns mean, drop-in replacement), predict_with_uncertainty (mean + std + raw outputs), and from_checkpoints factory; correctly placed in the experimental tree per MOD-002a, but has a P1 discrepancy between the from_checkpoints docstring (claiming .mdlus support) and the raw state_dict-only implementation, plus several MOD-003/005/006/010 standard violations.
  • Tests (test/models/test_ensemble_wrapper.py): 18 checks covering shape, math correctness, and checkpoint round-trips, but missing reference-data non-regression tests (.pth files) and the PhysicsNeMo .mdlus checkpoint loading test required by MOD-008b/c.
  • Example (examples/structural_mechanics/ensemble_uq/): clean, self-contained beam deflection demo with correctly isolated requirements.txt.

Important Files Changed

Filename Overview
physicsnemo/experimental/models/ensemble_wrapper.py New EnsembleWrapper module; has a P1 bug (misleading .mdlus support claim in from_checkpoints docstring vs raw state_dict-only implementation) and several MOD-00x standard violations (missing Forward/Outputs class sections, no shape validation, no jaxtyping, splatted kwargs).
test/models/test_ensemble_wrapper.py Good coverage of shapes and math, but missing reference-data non-regression tests (.pth files) and .mdlus checkpoint round-trip test per MOD-008b/c.
examples/structural_mechanics/ensemble_uq/ensemble_uq_beam.py Self-contained beam deflection example demonstrating EnsembleWrapper; clean script with proper docstring and no issues.
examples/structural_mechanics/ensemble_uq/requirements.txt Example-only dependency file listing matplotlib; correctly isolated from core package per EXT-001.

Comments Outside Diff (6)

  1. physicsnemo/experimental/models/ensemble_wrapper.py, line 555-567 (link)

    P1 Misleading .mdlus checkpoint support claim

    The docstring states checkpoints can be loaded if saved "via PhysicsNeMo's built-in checkpoint utilities," but the implementation calls torch.load(path, weights_only=True) directly — which cannot deserialize a .mdlus archive (a tar/zip format). Passing a PhysicsNeMo-native .mdlus path will raise an error at runtime rather than loading the weights. The docstring should either restrict the accepted format to raw state_dict files or implement proper .mdlus handling via physicsnemo.utils.checkpoint.
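
    One way to implement proper handling is to dispatch on the file suffix. A stdlib-only sketch — the returned strings stand in for the real loader calls (`physicsnemo.Module.from_checkpoint` for `.mdlus` archives, `torch.load(path, weights_only=True)` plus `load_state_dict` for raw state_dicts):

    ```python
    from pathlib import Path


    def checkpoint_kind(path: str) -> str:
        """Classify a checkpoint path so the caller can pick the right loader.

        Sketch only: "mdlus" would route to physicsnemo.Module.from_checkpoint,
        "state_dict" to torch.load(path, weights_only=True) + load_state_dict.
        """
        suffix = Path(path).suffix
        if suffix == ".mdlus":
            return "mdlus"
        if suffix in {".pt", ".pth"}:
            return "state_dict"
        raise ValueError(f"Unsupported checkpoint format: {suffix!r}")
    ```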

  2. physicsnemo/experimental/models/ensemble_wrapper.py, line 403-457 (link)

    P2 MOD-003c: Missing required Forward and Outputs class docstring sections

    Per MOD-003c, the class docstring must contain Parameters, Forward, and Outputs sections. EnsembleWrapper only documents Parameters, Raises, Examples, and See Also. The Forward section should describe the input tensor signature (currently documented only inside the forward method's own docstring), and the Outputs section should describe the return tensor shape/type at the class level.

    File Used: CODING_STANDARDS/MODELS_IMPLEMENTATION.md (source)
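
    The required layout might look like the following; the section contents are illustrative, not taken from the actual module:

    ```python
    class EnsembleWrapper:
        r"""Ensemble of independently trained members (MOD-003c layout sketch).

        Parameters
        ----------
        members : list
            Trained model instances sharing the same input/output signature.

        Forward
        -------
        x : torch.Tensor
            Input batch of shape (B, in_features), passed to every member.

        Outputs
        -------
        torch.Tensor
            Ensemble mean of shape (B, out_features).
        """
    ```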

  3. physicsnemo/experimental/models/ensemble_wrapper.py, line 483-538 (link)

    P2 MOD-005: Missing tensor shape validation in public methods

    Both forward and predict_with_uncertainty accept a tensor x without any shape validation guarded by torch.compiler.is_compiling(). Per MOD-005, all forward and public methods that accept tensor arguments must validate shapes at entry, wrapped in if not torch.compiler.is_compiling():. Without this, mismatched shapes produce cryptic errors deep in the member model graph.

    File Used: CODING_STANDARDS/MODELS_IMPLEMENTATION.md (source)
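
    A guard along the lines the review asks for might look like this; the error wording and the `in_features` argument are assumptions:

    ```python
    import torch


    def validate_input(x: torch.Tensor, in_features: int) -> None:
        # Per the MOD-005 pattern cited in the review: skip eager-mode
        # shape checks when running under torch.compile.
        if not torch.compiler.is_compiling():
            if x.dim() != 2 or x.shape[-1] != in_features:
                raise ValueError(
                    f"Expected input of shape (B, {in_features}), "
                    f"got {tuple(x.shape)}"
                )
    ```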

  4. physicsnemo/experimental/models/ensemble_wrapper.py, line 483-503 (link)

    P2 MOD-006: Missing jaxtyping tensor annotations

    Per MOD-006, all tensor arguments in public methods must use jaxtyping annotations (e.g. Float[torch.Tensor, "batch ..."]). Since EnsembleWrapperMeta sets jit=False, TorchScript incompatibility is not a concern here. Both forward and predict_with_uncertainty currently use plain torch.Tensor annotations instead.

    File Used: CODING_STANDARDS/MODELS_IMPLEMENTATION.md (source)

  5. physicsnemo/experimental/models/ensemble_wrapper.py, line 544-551 (link)

    P2 MOD-010: Splatted **model_kwargs in class-level constructor

    Per MOD-010, splatted kwargs should be avoided in model constructors because they obscure what parameters are accepted and risk naming collisions. from_checkpoints is effectively a factory constructor; accepting **model_kwargs makes the valid keyword arguments invisible to IDEs and static checkers. Consider accepting an explicit model_config: Optional[Dict[str, Any]] = None parameter instead.

    File Used: CODING_STANDARDS/MODELS_IMPLEMENTATION.md (source)

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
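
    The suggested signature might be sketched as follows; the function name and config handling are illustrative, and weight loading is elided:

    ```python
    from typing import Any, Dict, List, Optional


    def from_checkpoints_sketch(
        model_cls: type,
        checkpoint_paths: List[str],
        model_config: Optional[Dict[str, Any]] = None,
    ) -> List[Any]:
        """Factory sketch: one explicit config dict instead of splatted kwargs."""
        config = dict(model_config or {})
        # Each member is constructed from the same, visible configuration;
        # loading weights from checkpoint_paths is elided in this sketch.
        return [model_cls(**config) for _ in checkpoint_paths]
    ```

    An explicit `model_config` keeps the accepted parameters discoverable by IDEs and static checkers, which is the concern MOD-010 raises about `**model_kwargs`.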

  6. test/models/test_ensemble_wrapper.py, line 809-875 (link)

    P2 MOD-008b/c: Tests don't meet non-regression or checkpoint-loading standards

    Two gaps relative to the coding standards:

    1. MOD-008b: Tests verify output shapes and math properties but never compare tensor values against reference data saved in a .pth file. The standard requires loading pre-saved reference outputs and asserting torch.allclose against them so that silent numerical regressions are caught.
    2. MOD-008c: The checkpoint test saves raw state_dict files and exercises EnsembleWrapper.from_checkpoints(), but MOD-008c requires testing the PhysicsNeMo checkpoint path — saving with model.save("*.mdlus") and loading with physicsnemo.Module.from_checkpoint("*.mdlus") — to validate that the full .mdlus serialization round-trip works correctly.

    File Used: CODING_STANDARDS/MODELS_IMPLEMENTATION.md (source)
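
    The MOD-008b pattern described above might be sketched like this; path handling and tolerances are assumptions:

    ```python
    import os

    import torch


    def check_against_reference(output: torch.Tensor, ref_path: str) -> bool:
        """Non-regression sketch: compare an output against saved reference data.

        On the first run the reference .pth file is materialized; afterwards
        any silent numerical drift makes the allclose check fail.
        """
        if not os.path.exists(ref_path):
            torch.save(output, ref_path)
            return True
        reference = torch.load(ref_path, weights_only=True)
        return torch.allclose(output, reference, rtol=1e-5, atol=1e-8)
    ```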

Reviews (1): Last reviewed commit: "feat(experimental): add model-agnostic E..."
