
run_prediction.py fails with MPS (with diff solution) #744

@oytuntez

Description

When running with MPS, run_prediction.py fails because FastSurfer allocates the prediction tensor as torch.float16, while MPS (Metal Performance Shaders) on Apple Silicon has limited float16 support, in particular for the in-place add_ operation with alpha scaling. The same fix has already been applied to HypVINN but is still missing in FastSurfer: on MPS devices, the prediction tensor needs to use torch.float32.

The diff below implements the fix, in case it helps. I would like to fork the repo and open a PR with it, but I have not been able to do that yet.

diff --git forkSrcPrefix/FastSurferCNN/run_prediction.py forkDstPrefix/FastSurferCNN/run_prediction.py
index 277c3233fe2da41c5a91aac96653ad4c339bed97..d4c063e75a135e43a05b790a39814f056d440933 100644
--- forkSrcPrefix/FastSurferCNN/run_prediction.py
+++ forkDstPrefix/FastSurferCNN/run_prediction.py
@@ -387,9 +387,11 @@ class RunModelOnData:
             Predicted classes.
         """
         shape = orig_data.shape + (self.get_num_classes(),)
+        # Use float32 for MPS devices due to limited float16 support
+        dtype = torch.float32 if self.viewagg_device.type == "mps" else torch.float16
         kwargs = {
             "device": self.viewagg_device,
-            "dtype": torch.float16,
+            "dtype": dtype,
             "requires_grad": False,
         }
 
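For reference, the device-dependent dtype choice from the diff can be pulled out into a small standalone helper. This is only an illustrative sketch: `prediction_dtype`, `allocate_prediction`, the toy shape, and the `add_(..., alpha=0.5)` accumulation step are hypothetical and are not FastSurfer's actual code. It assumes a PyTorch version that recognizes the "mps" device type (1.12 or later).

```python
import torch


def prediction_dtype(device: torch.device) -> torch.dtype:
    """Hypothetical helper mirroring the proposed fix.

    Use float32 on MPS (limited float16 support there) and keep
    float16 on every other device to save memory.
    """
    return torch.float32 if device.type == "mps" else torch.float16


def allocate_prediction(spatial_shape, num_classes, device):
    """Allocate a zeroed (spatial..., num_classes) accumulator on `device`."""
    shape = tuple(spatial_shape) + (num_classes,)
    return torch.zeros(
        shape,
        device=device,
        dtype=prediction_dtype(device),
        requires_grad=False,
    )


# Hypothetical usage: accumulate weighted per-view outputs in place with
# add_(..., alpha=...), the kind of call reported to be problematic in float16 on MPS.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
pred = allocate_prediction((4, 4, 4), 3, device)
view_output = torch.ones(4, 4, 4, 3, device=device, dtype=pred.dtype)
pred.add_(view_output, alpha=0.5)
```

Keeping float16 on CUDA and CPU preserves the memory savings of the original code; only MPS pays the cost of the larger float32 accumulator.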
