I am trying to speed up the STN (spatial transformer network) sample from
https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html
I have installed the nightly version of PyTorch with conda:
$ python
Python 3.9.12 (main, Apr 5 2022, 01:52:34)
[Clang 12.0.0 ] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import platform
>>> print(platform.platform())
macOS-12.3.1-arm64-arm-64bit
>>> torch.__version__
'1.13.0.dev20220903'
>>> torch.device('mps')
device(type='mps')
I modified the device selection in the tutorial code:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# speed up for Apple M1
device = torch.device('mps')
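For reference, here is a minimal sketch of a more defensive way to pick the device, so the same script still runs on machines without MPS. The helper names pick_device_name and detect_device_name are hypothetical (they are not part of PyTorch or the tutorial), and this only guards against MPS being unavailable; it does not address the crash shown below.

```python
import importlib.util


def pick_device_name(cuda_ok: bool, mps_ok: bool) -> str:
    """Return a device name, preferring CUDA, then MPS, then CPU."""
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"


def detect_device_name() -> str:
    """Probe PyTorch for usable backends; return "cpu" if PyTorch is missing."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    # torch.backends.mps is only present in newer builds, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    mps_ok = mps is not None and mps.is_available()
    return pick_device_name(torch.cuda.is_available(), mps_ok)


device_name = detect_device_name()
print(device_name)
```

The resulting string can then be passed to torch.device(device_name) in place of the hard-coded 'mps'.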
However, the kernel crashes with the following error:
-:27:11: error: invalid input tensor shapes, indices shape and updates shape must be equal
-:27:11: note: see current operation: %25 = "mps.scatter_along_axis"(%23, %arg0, %24, %1) {mode = 6 : i32} : (tensor<150528xf32>, tensor<28xf32>, tensor<50176xi32>, tensor) -> tensor<150528xf32>
/AppleInternal/Library/BuildRoots/560148d7-a559-11ec-8c96-4add460b61a6/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1267: failed assertion `Error: MLIR pass manager failed'