Hello, I’ve been tinkering with torch on macOS Sequoia 15.3 (M3 Max with 16 CPU cores, 40 GPU cores and 128 GB RAM). The torch version is a fork I installed from here, @ffda73cfecffd15b50a85745e98d0641a5583b58, because transposed convolutions don’t work on the current stable and nightly branches. AFAIK it’s forked from torch 2.3.0.
I tried running a simple UNet with residual units (MONAI 1.4.0 implementations). When I first ran it in Jupyter, the kernel died without giving a reason, so I wrote a minimal example as a script:
```python
import torch
from monai.networks.nets import UNet
from torchsummary import summary

device = torch.device('mps')
q = 128  # spatial size of the input volume (q x q x q)

# Random input and target volumes, shape (batch, channels, D, H, W)
a = torch.rand((1, 1, q, q, q)).to(device)
b = torch.rand((1, 1, q, q, q)).to(device)

# 3D UNet with 2 residual units per block
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    num_res_units=2,
)
# print(summary(model, torch.rand((1,q,q,q)).shape))
model = model.to(device)
model.train()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

# A single forward and backward pass
output = model(a)
loss = criterion(output, b)
print(loss.item())
loss.backward()
print('checkpoint')
```
This throws the following error:
```
(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: error: 'mps.reshape' op the result shape is not compatible with the input shape
(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: note: see current operation: %5 = "mps.reshape"(%arg1, %4) : (tensor<1x64x16x16x16xf32>, tensor<4xsi32>) -> tensor<1x128x16x256xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: error: 'mps.reshape' op the result shape is not compatible with the input shape
(mpsFileLoc): /AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:307:0: note: see current operation: %5 = "mps.reshape"(%arg1, %4) : (tensor<1x64x16x16x16xf32>, tensor<4xsi32>) -> tensor<1x128x16x256xf32>
/AppleInternal/Library/BuildRoots/d187755d-b9a3-11ef-83e5-aabfac210453/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:975: failed assertion `original module failed verification'
zsh: abort python torchtest.py
```
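For reference, a reshape is only valid when it keeps the number of elements the same, and the `mps.reshape` in the log doesn't. A quick check in plain Python (no torch needed) with the two shapes copied from the log: the requested result holds exactly twice as many values as the input, the same count as a 128-channel 16x16x16 volume. This only shows *that* the shapes are incompatible; it doesn't say which op produced the bad reshape.

```python
from math import prod

# Shapes copied from the "mps.reshape" line in the log
src_shape = (1, 64, 16, 16, 16)  # input tensor of the reshape
dst_shape = (1, 128, 16, 256)    # requested result shape

src_numel = prod(src_shape)
dst_numel = prod(dst_shape)

print(src_numel, dst_numel, dst_numel / src_numel)  # 262144 524288 2.0

# The requested shape has as many elements as a 1x128x16x16x16 tensor
print(dst_numel == prod((1, 128, 16, 16, 16)))  # True
```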
When I remove the residual units, this runs normally. I’m confused as to why.
I’m also unsure whether I should be posting this here, since I can’t tell whether the problem stems from torch or from MONAI (the error seems to be thrown from MONAI’s ResidualUnit).
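One way to narrow down whether this is a torch/MPS issue or a MONAI one might be to rebuild a similar block in plain PyTorch and run it on both CPU and MPS. The sketch below is only an approximation under stated assumptions: `PlainResidualUnit3d` and `run_backward` are hypothetical names, and the layout (a strided 3x3x3 conv followed by a second conv, each with InstanceNorm3d + PReLU, plus a strided or 1x1x1 conv on the skip path) is my guess at roughly what a residual unit does, not MONAI’s actual `ResidualUnit` code. If the plain version also aborts on `mps` but runs on `cpu`, that would point more towards torch than MONAI.

```python
import torch
import torch.nn as nn


class PlainResidualUnit3d(nn.Module):
    """Hypothetical plain-PyTorch stand-in for a residual unit (NOT MONAI's code).

    Main path: two 3x3x3 convs (the first one strided), each followed by
    InstanceNorm3d + PReLU. Skip path: a strided 3x3x3 conv when stride != 1,
    otherwise a 1x1x1 conv, so both paths produce the same output shape.
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.PReLU(),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.PReLU(),
        )
        k = 3 if stride != 1 else 1  # assumed skip-path kernel size
        self.skip = nn.Conv3d(in_ch, out_ch, kernel_size=k, stride=stride, padding=k // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.main(x) + self.skip(x)


def run_backward(device: torch.device, in_ch: int, out_ch: int,
                 stride: int, size: int) -> float:
    """One forward + backward pass of the unit on `device`; returns the MSE loss."""
    torch.manual_seed(0)
    unit = PlainResidualUnit3d(in_ch, out_ch, stride).to(device)
    x = torch.rand((1, in_ch, size, size, size), device=device)
    out = unit(x)
    loss = nn.functional.mse_loss(out, torch.rand_like(out))
    loss.backward()
    return loss.item()


if __name__ == "__main__":
    devices = [torch.device("cpu")]
    if torch.backends.mps.is_available():
        devices.append(torch.device("mps"))
    for dev in devices:
        # First downsampling block of the UNet above: 1 -> 16 channels, stride 2
        print(dev, run_backward(dev, in_ch=1, out_ch=16, stride=2, size=128))
        # Input shaped like the 1x64x16x16x16 tensor in the log, going to 128 channels
        print(dev, run_backward(dev, in_ch=64, out_ch=128, stride=1, size=16))
```

Since the MPS failure is an abort rather than a Python exception, it’s probably easiest to run each device and configuration as a separate process.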