RuntimeError: required rank 4 tensor to use channels_last format

My transformer training loop works correctly when I train on the CPU, but when I switch to MPS I get the error below when calling loss.backward() on the cross-entropy loss. The task is machine translation.
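
For reference, DEVICE is set along these lines (just a sketch; the exact definition lives earlier in my script):

import torch

# select MPS when available, otherwise fall back to the CPU
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
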
Dimensions (see the shape sanity check after this list):

model output: (batch_size, seq_len, target_vocab_size)
model output reshaped: (batch_size * seq_len, target_vocab_size)
target: (batch_size, seq_len)
target reshaped: (batch_size * seq_len) 
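
As a sanity check that these shapes are what torch.nn.CrossEntropyLoss expects, here is a minimal CPU-only sketch with dummy tensors of the same sizes; it backpropagates without error:

import torch

batch_size, seq_len, vocab = 64, 60, 1446
out = torch.randn(batch_size, seq_len, vocab, requires_grad=True)  # stand-in logits
trg = torch.randint(0, vocab, (batch_size, seq_len))               # stand-in token ids

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(out.view(-1, vocab), trg.view(-1))  # (3840, 1446) vs (3840,)
loss.backward()  # fine on the CPU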

Terminal output:

out:  torch.Size([64, 60, 1446])
trg:  torch.Size([64, 60])
out reshaped:  torch.Size([3840, 1446])
trg reshaped:  torch.Size([3840])
Traceback (most recent call last):
  File "~/transformer_training.py", line 198, in <module>
    train_loss.backward()
  File "~/myenv2/lib/python3.9/site-packages/torch/_tensor.py", line 401, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "~/myenv2/lib/python3.9/site-packages/torch/autograd/__init__.py", line 191, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: required rank 4 tensor to use channels_last format
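
The reshaped sizes are consistent (64 * 60 = 3840), and the same loop runs on the CPU, so the failure looks specific to the backward pass on the MPS backend.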

Here is my code for reference:

# causal (lower-triangular) attention mask
mask = torch.tril(torch.ones((MAX_LENGTH, MAX_LENGTH))).to(DEVICE)

# optimization loop
best_loss = 1e5
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
train_losses = []
val_losses = []
for epoch in range(1, EPOCHS + 1):

    # train loop
    for i, (src, trg) in enumerate(train_data):

        # move the batch to the device as long (token-id) tensors
        src = torch.Tensor(src).to(DEVICE).long()
        trg = torch.Tensor(trg).to(DEVICE).long()

        # forward pass
        out = model(src, trg, mask)
        print('out: ', out.size())
        print('trg: ', trg.size())
        print('out reshaped: ', out.view(-1, tgt_vocab).size())
        print('trg reshaped: ', trg.view(-1).size())

        # compute loss
        train_loss = loss_fn(out.view(-1, tgt_vocab), trg.view(-1))
        
        # backprop 
        optimizer.zero_grad()
        train_loss.backward()

        # update weights 
        optimizer.step()
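
The only shape manipulation in the loss path is the pair of .view() calls, so one variant I can try is .reshape(), which copies when the input isn't contiguous (just a sketch, I don't know whether it sidesteps the MPS error):

# hypothetical variant inside the loop: reshape() instead of view()
train_loss = loss_fn(out.reshape(-1, tgt_vocab), trg.reshape(-1))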