My transformer training loop seems to work correctly when I train on the CPU, but when I switch the device to MPS, I get the error below when calling loss.backward() on a cross-entropy loss. I am doing machine translation.
Dimensions:
model output: (batch_size, seq_len, target_vocab_size)
model output reshaped: (batch_size * seq_len, target_vocab_size)
target: (batch_size, seq_len)
target reshaped: (batch_size * seq_len)
Terminal output:
out: torch.Size([64, 60, 1446])
trg: torch.Size([64, 60])
out reshaped: torch.Size([3840, 1446])
trg reshaped: torch.Size([3840])
Traceback (most recent call last):
File "~/transformer_training.py", line 198, in <module>
train_loss.backward()
File "~/myenv2/lib/python3.9/site-packages/torch/_tensor.py", line 401, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "~/myenv2/lib/python3.9/site-packages/torch/autograd/__init__.py", line 191, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: required rank 4 tensor to use channels_last format
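For what it's worth, the reshaped tensors with the exact sizes printed above go through nn.CrossEntropyLoss and backward() fine in a standalone CPU check (everything below is a dummy stand-in for illustration, not my real model output):

import torch

# dummy stand-ins using the sizes from the printout above
out = torch.randn(64 * 60, 1446, requires_grad=True)  # (batch_size * seq_len, target_vocab_size)
trg = torch.randint(0, 1446, (64 * 60,))              # (batch_size * seq_len,)

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(out, trg)
loss.backward()  # runs without error on CPU
print(loss.item())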
Here is my code for reference:
mask = torch.tril(torch.ones((MAX_LENGTH, MAX_LENGTH))).to(DEVICE)

# optimization loop
best_loss = 1e5
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = torch.nn.CrossEntropyLoss()
train_losses = []
val_losses = []
for epoch in range(1, EPOCHS + 1):
    # train loop
    for i, (src, trg) in enumerate(train_data):
        # move tensors to the device
        src = torch.Tensor(src).to(DEVICE).long()
        trg = torch.Tensor(trg).to(DEVICE).long()
        # forward pass
        out = model(src, trg, mask)
        print('out: ', out.size())
        print('trg: ', trg.size())
        print('out reshaped: ', out.view(-1, tgt_vocab).size())
        print('trg reshaped: ', trg.view(-1).size())
        # compute loss
        train_loss = loss_fn(out.view(-1, tgt_vocab), trg.view(-1))
        # backprop
        optimizer.zero_grad()
        train_loss.backward()
        # update weights
        optimizer.step()
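Nothing in the code above asks for channels_last explicitly. As far as I understand, channels_last is a memory format meant for rank-4 (NCHW) tensors, and the same message can be reproduced on CPU by forcing it onto a 2D tensor (this is just my own sanity check, not part of the training code):

import torch

x = torch.randn(3840, 1446)
x.contiguous(memory_format=torch.channels_last)
# RuntimeError: required rank 4 tensor to use channels_last format

So my question is: why does loss.backward() hit this channels_last requirement only when DEVICE is MPS, and how can I avoid it?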