Assertion failed: Runtime canonicalization must simplify reduction axes to minor 4 dimensions

Hello,

I’m trying to train a model with the AlexNet architecture. Training works on a CUDA device in Google Colab, but when I try to train it on my Mac using the MPS backend, I get the following error:

Assertion failed: (0 <= mpsAxis && mpsAxis < 4 && "Runtime canonicalization must simplify reduction axes to minor 4 dimensions."), function encodeNDArrayOp, file GPUReductionOps.mm, line 76.

Python crashes and my kernel dies, so I’m at a loss as to how to debug this. I would greatly appreciate any help or hints pointing me in the right direction.