RuntimeError: required rank 4 tensor to use channels_last format

My transformer training loop works correctly when I train on the CPU, but when I switch to MPS I get the error below from loss.backward() with Cross Entropy loss. I am doing machine translation.

model output: (batch_size, seq_len, target_vocab_size)
model output reshaped: (batch_size * seq_len, target_vocab_size)
target: (batch_size, seq_len)
target reshaped: (batch_size * seq_len) 
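For reference, this shaping convention runs without error on CPU in a minimal standalone sketch (the dimensions here are arbitrary placeholders, not my real model's):

```python
import torch

# Hypothetical small dimensions, for illustration only
batch_size, seq_len, vocab = 4, 6, 10

out = torch.randn(batch_size, seq_len, vocab, requires_grad=True)  # fake model output
trg = torch.randint(0, vocab, (batch_size, seq_len))               # fake target indices

loss_fn = torch.nn.CrossEntropyLoss()
# Flatten to (N, C) logits vs (N,) class indices, as CrossEntropyLoss expects
loss = loss_fn(out.view(-1, vocab), trg.view(-1))
loss.backward()  # completes without error on CPU

print(out.grad.shape)  # torch.Size([4, 6, 10])
```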

Terminal output:

out:  torch.Size([64, 60, 1446])
trg:  torch.Size([64, 60])
out reshaped:  torch.Size([3840, 1446])
trg reshaped:  torch.Size([3840])
Traceback (most recent call last):
  File "~/", line 198, in <module>
  File "~/myenv2/lib/python3.9/site-packages/torch/", line 401, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "~/myenv2/lib/python3.9/site-packages/torch/autograd/", line 191, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: required rank 4 tensor to use channels_last format

Here is my code for reference:

# causal (lower-triangular) attention mask
mask = torch.tril(torch.ones((MAX_LENGTH, MAX_LENGTH))).to(DEVICE)

# optimization loop 
best_loss = 1e5
loss_fn = torch.nn.CrossEntropyLoss() 
train_losses = []
val_losses = []
for epoch in range(1,EPOCHS+1):

    # train loop 
    for i, (src,trg) in enumerate(train_data):

        # move tensors to device as integer index tensors
        src = torch.as_tensor(src, dtype=torch.long, device=DEVICE)
        trg = torch.as_tensor(trg, dtype=torch.long, device=DEVICE)

        # forward pass 
        out = model(src,trg, mask)
        print('out: ', out.size())
        print('trg: ', trg.size())
        print('out reshaped: ', out.view(-1, tgt_vocab).size())
        print('trg reshaped: ', trg.view(-1).size())

        # compute loss
        train_loss = loss_fn(out.view(-1, tgt_vocab), trg.view(-1))

        # backprop
        optimizer.zero_grad()
        train_loss.backward()  # <-- the RuntimeError is raised here on MPS

        # update weights
        optimizer.step()