STN training crashes on Apple M1

I am trying to speed up the STN (spatial transformer network) sample from
https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

I installed the nightly version with conda:
```
$ python
Python 3.9.12 (main, Apr 5 2022, 01:52:34)
[Clang 12.0.0 ] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import platform
>>> print(platform.platform())
macOS-12.3.1-arm64-arm-64bit
>>> torch.__version__
'1.13.0.dev20220903'
>>> torch.device('mps')
device(type='mps')
```
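Note that `torch.device('mps')` only constructs a device descriptor; it succeeds even on machines where the MPS backend cannot run. A more direct check, sketched below, queries `torch.backends.mps.is_built()` (the binary was compiled with MPS support) and `torch.backends.mps.is_available()` (the running macOS and hardware can actually use it). The helper name `mps_status` is hypothetical and only for illustration.

```python
import torch

def mps_status() -> dict:
    """Hypothetical helper: report whether MPS is compiled in and usable."""
    return {
        # True if this PyTorch build includes the MPS backend
        "built": torch.backends.mps.is_built(),
        # True if the backend is built AND the OS/hardware support it
        "available": torch.backends.mps.is_available(),
    }

print(mps_status())
```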

I modified the code as follows:

```
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# speed up for Apple M1
device = torch.device('mps')
```
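Hard-coding `'mps'` makes the script fail on machines without the backend. A minimal sketch of a fallback that keeps the tutorial's original CUDA/CPU logic, assuming a PyTorch build that ships `torch.backends.mps` (the 1.13 nightly above does); `pick_device` is a hypothetical name:

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer MPS, then CUDA, then fall back to CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

device = pick_device()
print(device)
```

Keeping the CPU fallback also makes it easy to rerun the same script off-GPU and check whether a crash is specific to the MPS backend.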

However, the kernel crashes with the following error:
```
-:27:11: error: invalid input tensor shapes, indices shape and updates shape must be equal
-:27:11: note: see current operation: %25 = "mps.scatter_along_axis"(%23, %arg0, %24, %1) {mode = 6 : i32} : (tensor<150528xf32>, tensor<28xf32>, tensor<50176xi32>, tensor) -> tensor<150528xf32>
/AppleInternal/Library/BuildRoots/560148d7-a559-11ec-8c96-4add460b61a6/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1267: failed assertion `Error: MLIR pass manager failed'
```
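To narrow down which operation triggers the assertion, one option is to run a single forward and backward pass through the `F.affine_grid`/`F.grid_sample` pair that the tutorial's STN uses, first on the CPU and then on MPS. This is a sketch under assumptions: the helper name `stn_step`, the batch size of 4 and the identity transform are illustrative choices, and nothing in the log above confirms that these two ops are the culprit.

```python
import torch
import torch.nn.functional as F

def stn_step(device: str, batch: int = 4) -> torch.Tensor:
    """Hypothetical helper: one forward/backward pass through the
    affine_grid + grid_sample pair, returning d(loss)/d(theta) on the CPU."""
    torch.manual_seed(0)
    # Build inputs on the CPU, then move them, so every device sees identical values
    x = torch.randn(batch, 1, 28, 28).to(device)
    identity = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    # One 2x3 affine matrix per sample, tracked for gradients
    theta = identity.repeat(batch, 1, 1).to(device).requires_grad_()
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    out = F.grid_sample(x, grid, align_corners=False)
    out.sum().backward()
    return theta.grad.cpu()

grad_cpu = stn_step("cpu")
print(grad_cpu.shape)
```

On the CPU this should give a finite 4×2×3 gradient. Calling `stn_step("mps")` afterwards either reproduces the crash (which takes down the whole interpreter, since it is a failed assertion rather than a Python exception) or returns a gradient that can be compared with `torch.allclose(grad_cpu, stn_step("mps"), atol=1e-4)`.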