Training a transformer-based model for action detection returns NaNs after a few epochs

Hi! I’ve been experimenting with this model for Action Detection, and I modified it to run on a custom dataset.
The idea is that I first use a 3D CNN (an I3D network) to extract features from the videos. The transformer-based model then combines these features with predictions from the previous time steps to form the input sequences for the encoder and decoder.
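For concreteness, the data flow can be sketched roughly like this (the dimensions, layer names, and use of `nn.Transformer` here are illustrative assumptions, not the actual model):

```python
import torch
import torch.nn as nn

# Hypothetical dimensions: I3D clip features are 1024-d, transformer runs at d_model=256
feat_dim, d_model = 1024, 256

reduce = nn.Linear(feat_dim, d_model)                      # the "linear encoding layer"
transformer = nn.Transformer(d_model=d_model, batch_first=True)

feats = torch.randn(2, 16, feat_dim)                       # (batch, time, feature) from I3D
src = reduce(feats)                                        # encoder input
tgt = torch.randn(2, 16, d_model)                          # decoder input built from previous predictions
out = transformer(src, tgt)
print(out.shape)                                           # torch.Size([2, 16, 256])
```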

The problem is that, during training, the loss becomes NaN after a few epochs (always at epoch 6).
I checked, and both the input features and the previous predictions are fine (i.e., they contain no NaNs) up to the call to loss_train.backward(). After that call, the weights of a linear encoding layer used to reduce the feature dimensions become NaN, and everything breaks from there.
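One way to narrow down where the NaN first appears is to scan every gradient for non-finite values immediately after backward(). A minimal sketch with a toy model (the real model and parameter names will of course differ):

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model; only the checking loop matters here
model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Report any parameter whose gradient contains NaN or Inf
for name, param in model.named_parameters():
    if param.grad is not None and not torch.isfinite(param.grad).all():
        print(f"non-finite gradient in {name}")
```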

I ran the training again inside with torch.autograd.set_detect_anomaly(True):, and now it reports the following error:

File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/c.demasi/work/projects/ball_shot_action_detection_dev_environment/./external/simon/simon/simon.py", line 105, in forward
    outs = self.transformer(queries, memory, target_mask)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 460, in forward
    output = mod(output, memory, tgt_mask=tgt_mask,
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 846, in forward
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 855, in _sa_block
    x = self.self_attn(x, x, x,
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/c.demasi/anaconda3/envs/simon_py310/lib/python3.10/site-packages/torch/nn/functional.py", line 5440, in multi_head_attention_forward
    attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Any general suggestions?

As a follow-up, I added the following check right after the backward pass:

loss.backward()
for name, param in model.named_parameters():
    if param.requires_grad and param.grad is None:
        print(name)  # report parameters that received no gradient

and I see that the parameter feature_start_emb.embedding.weight has a None gradient.
This presumably comes from this embedding operation:

self.feature_start_emb = TokenEmbedding(1, 256)
feature_start_emb = self.feature_start_emb(torch.tensor([0]).cuda())

where

import math

import torch.nn as nn
from torch import Tensor


class TokenEmbedding(nn.Module):
    # helper Module to convert a tensor of input indices into the corresponding tensor of token embeddings
    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor) -> Tensor:
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

Your code will fail with a device mismatch:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

since the TokenEmbedding module wasn’t moved to the GPU yet.
After fixing that, the weight shows valid gradients:

feature_start_emb = TokenEmbedding(1, 256).cuda()
out = feature_start_emb(torch.tensor([0]).cuda())

out.mean().backward()
print(feature_start_emb.embedding.weight.grad.abs().sum())
# tensor(16., device='cuda:0')
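As a side note: when the embedding is assigned as an attribute of the parent nn.Module (as in self.feature_start_emb = TokenEmbedding(1, 256)), it is registered as a submodule, so a single .cuda() or .to(device) call on the parent moves its weights along with everything else. A CPU sketch (Model here is a hypothetical wrapper, not the actual model):

```python
import math
import torch
import torch.nn as nn


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # assigned as an attribute -> registered as a submodule,
        # so Model().cuda() would move its weights too
        self.feature_start_emb = TokenEmbedding(1, 256)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        return self.feature_start_emb(idx)


model = Model()
out = model(torch.tensor([0]))
print(out.shape)  # torch.Size([1, 256])
```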