binary_cross_entropy crashes when data is on the MPS device

Hello everyone.

I’m currently experimenting with the MPS accelerator on my M1 Pro MacBook and have run into some weird behaviour.

This code works perfectly fine:

import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cpu")
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1),
    nn.Sigmoid()
)
model.to(device)
input = torch.randn(5, 10).to(device)
target = torch.randint(0, 2, (5, 1)).float().to(device)

F.binary_cross_entropy(
    model(input),
    target
)

But when I made a single change (“cpu” → “mps”):

device = torch.device("mps")
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 1),
    nn.Sigmoid()
)
model.to(device)
input = torch.randn(5, 10).to(device)
target = torch.randint(0, 2, (5, 1)).float().to(device)

F.binary_cross_entropy(
    model(input),
    target
)

everything crashed.

My interpreter died with the following message:

loc("reductionMeanTensor"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/b6051351-c030-11ec-96e9-3e7866fcf3a1/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":281:0)): error: invalid axes: 1
LLVM ERROR: Failed to infer result type(s).
[1]    8876 abort      ipython
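The log points at a mean reduction (`reductionMeanTensor ... invalid axes: 1`). One quick check, assuming (not confirmed) that only the built-in "mean" reduction is broken, is to request the unreduced per-element loss with `reduction="none"` and average it in a separate step. This is a diagnostic sketch, not a fix. If the per-element MPS kernel is also affected it may still abort the interpreter, so it is best run in a throwaway Python process. Random probabilities stand in for `model(input)`.

```python
import torch
import torch.nn.functional as F

# Use MPS when this PyTorch build exposes it and the machine supports it, else CPU.
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
device = torch.device("mps" if use_mps else "cpu")

# Stand-ins for model(input) and target, created on CPU and then moved,
# the same way as in the original snippet.
probs = torch.rand(5, 1).to(device)
target = torch.randint(0, 2, (5, 1)).float().to(device)

# Assumption: the crash happens in the fused "mean" reduction, so skip it
# and take the mean of the per-element losses as a separate operation.
per_element = F.binary_cross_entropy(probs, target, reduction="none")
loss = per_element.mean()
print(device, loss.item())
```

If this runs while the default call aborts, that narrows the problem down to the reduction step inside the MPS `binary_cross_entropy` kernel rather than to the model or the data.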

Can anyone help me with this problem? Is this operation not yet supported on MPS, or is it my mistake?

P.S. I’ve encountered this only with the binary_cross_entropy operation. I tried some other experiments (for example, with cross_entropy) and they worked just fine.
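Since other losses work, a possible stopgap is to compute the loss from elementwise ops (`log`, `clamp`, multiply) plus a plain `Tensor.mean()`, bypassing the `binary_cross_entropy` kernel entirely. The sketch below assumes those ops are unaffected on MPS in 1.12.0, which has not been verified here. `manual_bce` is a hypothetical helper name, not a PyTorch API. It mirrors the documented behaviour of PyTorch's BCE, which clamps the log outputs to be at least -100, so probabilities of exactly 0 or 1 still give a finite loss.

```python
import torch
import torch.nn.functional as F


def manual_bce(probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for F.binary_cross_entropy with reduction="mean".

    Computes -(t * log(p) + (1 - t) * log(1 - p)) and averages over all elements.
    """
    # Clamp the logs at -100 like PyTorch does, so log(0) = -inf never
    # reaches the multiply (0 * -inf would produce NaN).
    log_p = torch.clamp(torch.log(probs), min=-100)
    log_1mp = torch.clamp(torch.log(1 - probs), min=-100)
    return -(target * log_p + (1 - target) * log_1mp).mean()


# Compare against the built-in loss on CPU, where binary_cross_entropy works.
p = torch.rand(5, 1)
t = torch.randint(0, 2, (5, 1)).float()
print(manual_bce(p, t).item(), F.binary_cross_entropy(p, t).item())
```

In the original setup it would be used as `loss = manual_bce(model(input), target)`. Note that gradients through the clamped terms are zero, so this is only a workaround until the MPS kernel behaves.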

My torch version is 1.12.0.
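For anyone trying to reproduce this, a short environment report helps. The snippet below is a sketch that collects the PyTorch version, the macOS version, and whether this build of PyTorch was compiled with MPS support and can use it. `mps_report` is a hypothetical helper name. The `getattr` guard is there so the script also runs on builds that lack `torch.backends.mps`.

```python
import platform

import torch


def mps_report() -> dict:
    """Hypothetical helper: gather the details usually asked for in MPS bug reports."""
    mps = getattr(torch.backends, "mps", None)
    return {
        "torch": torch.__version__,
        # platform.mac_ver() returns an empty version string on non-macOS systems.
        "macos": platform.mac_ver()[0] or "n/a",
        "mps_built": bool(mps is not None and mps.is_built()),
        "mps_available": bool(mps is not None and mps.is_available()),
    }


for key, value in mps_report().items():
    print(f"{key}: {value}")
```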