loss.backward() aborts on MPS (MONAI UNet with residual units)

Hello, I've been tinkering with torch on macOS Sequoia 15.3 (M3 Max with 16 CPU cores, 40 GPU cores and 128 GB RAM). The torch version is a fork I installed from here, @ffda73cfecffd15b50a85745e98d0641a5583b58, because transposed convolutions don't work on the current stable and nightly branches. AFAIK it's forked from torch 2.3.0.
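Since the exact build matters a lot for MPS issues, here is a minimal sketch for collecting the relevant version and backend details to paste into a report. It assumes only the standard `platform` module and `torch.backends.mps`; `mps_env_report` is a hypothetical helper name, and `python -m torch.utils.collect_env` prints a much fuller report if needed.

```python
import platform

import torch


def mps_env_report() -> dict:
    """Collect version/backend details worth pasting into an MPS bug report."""
    return {
        "torch": torch.__version__,
        # mac_ver() returns an empty version string on non-macOS systems
        "macos": platform.mac_ver()[0] or "n/a",
        "machine": platform.machine(),
        "mps_built": torch.backends.mps.is_built(),
        "mps_available": torch.backends.mps.is_available(),
    }


if __name__ == "__main__":
    for key, value in mps_env_report().items():
        print(f"{key}: {value}")
```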

I tried running a simple UNet with residual units (MONAI 1.4.0 implementation). When I first ran it in Jupyter, the kernel died without any error message, so I wrote a minimal example as a script:

import torch
from monai.networks.nets import UNet
from torchsummary import summary

device = torch.device('mps')

q = 128
a = torch.rand((1,1,q,q,q)).to(device)
b = torch.rand((1,1,q,q,q)).to(device)

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)
# print(summary(model, torch.rand((1,q,q,q)).shape))
model = model.to(device)

model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

output = model(a)
loss = criterion(output, b)
print(loss.item())
loss.backward()  # the script aborts during this call
print('checkpoint')  # never reached

This throws the following error:

(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: error: 'mps.reshape' op the result shape is not compatible with the input shape
(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: note: see current operation: %5 = "mps.reshape"(%arg1, %4) : (tensor<1x64x16x16x16xf32>, tensor<4xsi32>) -> tensor<1x128x16x256xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: error: 'mps.reshape' op the result shape is not compatible with the input shape
(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: note: see current operation: %5 = "mps.reshape"(%arg1, %4) : (tensor<1x64x16x16x16xf32>, tensor<4xsi32>) -> tensor<1x128x16x256xf32>
/AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:975: failed assertion `original module failed verification'
zsh: abort      python torchtest.py
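A quick element count on the shapes in the failing `mps.reshape` (plain arithmetic on the log, no torch needed) shows why MPSGraph rejects it: the requested result holds exactly twice as many elements as the input, and its channel dimension is 128 instead of 64, while the trailing 256 is just 16 × 16. If I read MONAI's UNet correctly (an assumption worth checking against its source), the 128^3 input reaches the bottom block at 16^3 after three stride-2 stages, and that block is a ResidualUnit mapping 64 to 128 channels. That hints, without proving it, that a shape computed for the layer's output is being applied to its input somewhere in the backward graph.

```python
import math

# Shapes copied from the "mps.reshape" note in the error log.
src = (1, 64, 16, 16, 16)  # input tensor of the reshape
dst = (1, 128, 16, 256)    # requested result shape

src_numel = math.prod(src)  # 262144
dst_numel = math.prod(dst)  # 524288

print(src_numel, dst_numel, dst_numel // src_numel)  # 262144 524288 2

# 256 == 16 * 16: the target looks like the 16^3 volume with the last two
# spatial dims merged, but with 128 channels instead of 64.
print(dst[-1] == src[-2] * src[-1])  # True
```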

When I remove the residual units, the script runs normally, and I'm confused as to why.
I'm also unsure whether I should be posting this here, since I can't tell whether the error stems from torch or from MONAI (the error is thrown from their ResidualUnit).
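One way to separate torch from MONAI is to rebuild the suspected layer in plain torch. As far as I can tell from MONAI's ResidualUnit, when a unit changes the channel count at stride 1, its residual path is a 1x1x1 convolution; at the bottom of this UNet that would be `Conv3d(64, 128, kernel_size=1)` applied to a 1x64x16x16x16 tensor, the same input shape as in the log. Treat that layer choice as a hypothesis; `conv1x1_backward` is a hypothetical helper. If the MPS run aborts, the problem lies in torch's MPS backend rather than in MONAI; if it passes, the other residual-path convolutions are the next candidates.

```python
import torch


def conv1x1_backward(device: torch.device) -> torch.Tensor:
    """Forward/backward a 1x1x1 Conv3d (64 -> 128 channels) on a 16^3 volume.

    The shapes mirror the tensors in the MPSGraph error; that this is the
    layer triggering the crash is an assumption, not a confirmed fact.
    """
    torch.manual_seed(0)  # same weights and input on every device
    conv = torch.nn.Conv3d(64, 128, kernel_size=1, stride=1, padding=0)
    x = torch.rand(1, 64, 16, 16, 16)

    conv = conv.to(device)
    x = x.to(device).requires_grad_()  # also exercise the grad-input path
    conv(x).sum().backward()
    return conv.weight.grad.detach().cpu()


if __name__ == "__main__":
    ref = conv1x1_backward(torch.device("cpu"))
    print("cpu weight-grad shape:", tuple(ref.shape))
    if torch.backends.mps.is_available():
        # If the process aborts here, plain torch reproduces the bug.
        mps = conv1x1_backward(torch.device("mps"))
        # Small float differences between CPU and MPS are expected.
        print("max |cpu - mps|:", (ref - mps).abs().max().item())
```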