loss.backward() throws an "mps.concat" error on M1, but works on CPU and GPU

Hi!
I’ve been trying to run a Seq2Seq model on my M1 Mac, but for some reason I get an error when calling loss.backward().
The error:

(mpsFileLoc): /AppleInternal/Library/BuildRoots/0aa643d0-625a-11ed-b319-a23c4f261b56/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:636:0: error: invalid input tensor shapes, all input shapes must match except at axis
(mpsFileLoc): /AppleInternal/Library/BuildRoots/0aa643d0-625a-11ed-b319-a23c4f261b56/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:636:0: note: see current operation: %28 = "mps.concat"(%26, %24, %27) : (tensor<3x1x8xf32>, tensor<1x4x8xf32>, tensor<si32>) -> tensor<4x1x8xf32>

I can see the shape mismatch between tensor<3x1x8xf32> and tensor<1x4x8xf32>, but no such concatenation appears anywhere in my code, and the same code works fine outside MPS.
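For context, the MPSGraph message states the same rule that torch.cat enforces: all inputs must have matching sizes on every axis except the one being concatenated. Going by the log's output type tensor<4x1x8xf32>, the op appears to concatenate along axis 0 and would have needed a 1x1x8 second input, but received 1x4x8 (this is a reading of the log, not something confirmed). A CPU-only illustration of the rule, using the shapes from the log:

```python
import torch

# Valid concat: shapes agree on every axis except the concat axis (dim 0)
a = torch.zeros(3, 1, 8)
b = torch.zeros(1, 1, 8)
ok = torch.cat([a, b], dim=0)
print(ok.shape)  # torch.Size([4, 1, 8]), the output shape reported in the MPS log

# The second input from the log: dim 1 differs (1 vs 4), so concat along dim 0 is invalid
c = torch.zeros(1, 4, 8)
mismatch_raised = False
try:
    torch.cat([a, c], dim=0)
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)
```

On CPU the invalid case surfaces as a Python RuntimeError; on MPS the same kind of mismatch is reported by MPSGraph instead, as in the log.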

Code:

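import torch


class OneLayerSeq2Seq(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Single-layer encoder and decoder LSTMs taking (batch, seq, feature) input
        self.encoder = torch.nn.LSTM(input_size, hidden_size, batch_first=True)
        self.decoder = torch.nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(hidden_size, output_size)

    def forward(self, input, future=0):
        outputs = []
        # Encode the whole input sequence; the final (hidden, cell) seed the decoder
        encoder_output, (hidden, cell) = self.encoder(input)
        encoder_output = self.relu(encoder_output)

        # Last encoder time step becomes the first decoder input: (batch, 1, hidden_size)
        decoder_input = encoder_output[:, -1, :].unsqueeze(1)
        for _ in range(future):
            # Decode one step at a time, feeding each output back in as the next input
            decoder_output, (hidden, cell) = self.decoder(decoder_input, (hidden, cell))
            decoder_input = self.relu(decoder_output)
            outputs += [self.linear(decoder_input)]
        # Join the per-step predictions along time: (batch, future, output_size).
        # future must be >= 1 here, because torch.cat fails on an empty list.
        outputs = torch.cat(outputs, dim=1)
        return outputs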
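One way to narrow this down is a self-contained repro that runs the same encoder-decoder pattern with plain torch.nn.LSTM layers and does a single backward pass, first on CPU and then on MPS if it is available. This is a sketch under assumptions: the sizes (batch 4, hidden 8, 3 future steps) are hypothetical values picked to resemble the shapes in the log, and the MSE loss and random data stand in for the training loop, which the post doesn't show. If the bug reproduces, the MPS run may abort the process rather than raise a Python exception, which is why the CPU run goes first.

```python
import torch

# Hypothetical sizes, chosen to resemble the shapes in the MPS log;
# the real data sizes are not given in the post.
BATCH, SEQ_LEN, INPUT_SIZE, HIDDEN, OUTPUT_SIZE, FUTURE = 4, 5, 2, 8, 1, 3


def run_one_step(device: str) -> float:
    """One forward/backward pass of the encoder-decoder pattern on `device`."""
    torch.manual_seed(0)
    encoder = torch.nn.LSTM(INPUT_SIZE, HIDDEN, batch_first=True).to(device)
    decoder = torch.nn.LSTM(HIDDEN, HIDDEN, batch_first=True).to(device)
    linear = torch.nn.Linear(HIDDEN, OUTPUT_SIZE).to(device)

    # Random stand-in data (assumption: the real dataset is not shown)
    x = torch.randn(BATCH, SEQ_LEN, INPUT_SIZE, device=device)
    target = torch.randn(BATCH, FUTURE, OUTPUT_SIZE, device=device)

    # Encode, then decode FUTURE steps, passing (h, c) explicitly as in the post
    enc_out, (h, c) = encoder(x)
    step = torch.relu(enc_out[:, -1, :]).unsqueeze(1)
    preds = []
    for _ in range(FUTURE):
        dec_out, (h, c) = decoder(step, (h, c))
        step = torch.relu(dec_out)
        preds.append(linear(step))

    # Assumed loss: MSE between stacked predictions and the random target
    loss = torch.nn.functional.mse_loss(torch.cat(preds, dim=1), target)
    loss.backward()  # the call that fails on MPS in the post
    return loss.item()


devices = ["cpu"]
if torch.backends.mps.is_available():
    devices.append("mps")
for dev in devices:
    print(dev, run_one_step(dev))
```

If the CPU run succeeds and only the MPS run fails, that suggests the problem lies in the MPS backend rather than in the model code, and a script of this size is convenient to attach to a PyTorch GitHub issue.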

Do you have any idea what might be causing it?
