BCELoss error with mps

I’m using BCELoss, but when I run it on the mps device I get the following error. The same code runs fine without mps.

loc("reductionMeanTensor"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/560148d7-a559-11ec-8c96-4add460b61a6/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":281:0)): error: invalid axes: 2
LLVM ERROR: Failed to infer result type(s).
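A minimal, self-contained script along these lines may help narrow the error down. It is a sketch, not a known fix. The tensor shapes are placeholders because the post does not give the real ones. The two alternatives (`reduction="none"` followed by `.mean()`, and a hand-written BCE) are hypothetical workarounds that avoid `nn.BCELoss`'s built-in mean reduction. Whether either one sidesteps the `reductionMeanTensor` failure on a given PyTorch build is an assumption to test, not something confirmed here.

```python
import sys

import torch
import torch.nn as nn


def bce_none_then_mean(probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Hypothetical workaround: per-element BCE with no reduction, then a separate mean.
    return nn.BCELoss(reduction="none")(probs, targets).mean()


def bce_manual(probs: torch.Tensor, targets: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    # Hypothetical workaround: hand-written BCE, mean of -(y*log(p) + (1-y)*log(1-p)).
    # Clamping keeps log() finite; nn.BCELoss instead clamps its log outputs at -100.
    probs = probs.clamp(eps, 1 - eps)
    loss = -(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))
    return loss.mean()


def pick_device() -> torch.device:
    # torch.backends.mps only exists in PyTorch 1.12+, so guard the lookup.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


VARIANTS = {
    "bceloss": lambda a, b: nn.BCELoss()(a, b),
    "none_mean": bce_none_then_mean,
    "manual": bce_manual,
}


if __name__ == "__main__":
    torch.manual_seed(0)
    # Placeholder shapes: the post does not say what the real inputs look like.
    probs = torch.rand(4, 3, 8).clamp(0.01, 0.99)
    targets = torch.randint(0, 2, (4, 3, 8)).float()

    reference = nn.BCELoss()(probs, targets)  # CPU result to compare against
    print("cpu BCELoss:", reference.item())

    device = pick_device()
    p, t = probs.to(device), targets.to(device)

    # An "LLVM ERROR" typically aborts the whole process instead of raising a
    # Python exception, so run one variant per invocation, e.g. `python repro.py manual`.
    name = sys.argv[1] if len(sys.argv) > 1 else "bceloss"
    value = VARIANTS[name](p, t).cpu()
    print(f"{device} {name}:", value.item(), "matches cpu:", torch.allclose(value, reference))
```

Running it with `bceloss`, `none_mean`, and `manual` in turn shows which formulations, if any, reproduce the crash on the mps device.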