Loss is nan, stopping training in MultiheadAttention

I encountered ‘Loss is nan, stopping training’ when training my model with an additional MultiheadAttention module. I have verified that when the if block below is skipped, training runs without error. Can anyone spot what’s causing the NaN in this part of the code?

    def forward(self, x):
        x = self.forward_features(x)  # [B, N, C] = [32, 198, 384]
        if self.add_attn:  # This if block is causing the `nan`
            B, N, C = x.shape  # batch_size, sequence_length, embed_dim
            x_qkv = self.add_forward(x)  # nn.Linear(self.embed_dim, 3 * self.embed_dim) -> [B, N, 3C]
            x_qkv = x_qkv.reshape(B, N, 3, C).permute(2, 1, 0, 3)  # -> [3, N, B, C]
            # self.add_attn is nn.MultiheadAttention(embed_dim=self.embed_dim, num_heads=12);
            # with the default batch_first=False it expects q, k, v each as [N, B, C]
            x, _ = self.add_attn(x_qkv[0], x_qkv[1], x_qkv[2])
            x = x.permute(1, 0, 2)  # back to [B, N, C]
        x = self.GeM(x, axis=1)  # pooling over the sequence dimension
        x = self.head(x)
        return x
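
For reference, here is a minimal standalone sketch of the attention branch (the shapes and module definitions are taken from the comments above; everything else is assumed), in case it helps to run it in isolation:

    import torch
    import torch.nn as nn

    # Standalone sketch of the attention branch; embed_dim, num_heads and the
    # input shape match the snippet above, the rest of the model is omitted.
    embed_dim, num_heads = 384, 12
    add_forward = nn.Linear(embed_dim, 3 * embed_dim)
    add_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

    x = torch.randn(32, 198, embed_dim)  # [B, N, C]
    B, N, C = x.shape
    x_qkv = add_forward(x).reshape(B, N, 3, C).permute(2, 1, 0, 3)  # [3, N, B, C]
    out, _ = add_attn(x_qkv[0], x_qkv[1], x_qkv[2])  # q, k, v each [N, B, C]
    out = out.permute(1, 0, 2)  # back to [B, N, C]
    print(torch.isnan(out).any())  # prints tensor(False) with fresh weights

With freshly initialized weights this does not produce NaNs, so the NaN presumably only appears during training.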

Hey @xincz
Could you please share a Colab notebook or something similar where we can run the model and try debugging it?

With only a code snippet, it is not easy to debug or point to the issue, for that matter.
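
In the meantime, one thing worth trying (a generic debugging step, not specific to your model) is PyTorch’s anomaly detection, which makes the backward pass raise an error at the first operation that produces a NaN:

    import torch

    # Enable before the training loop; it slows training noticeably,
    # so keep it only while debugging.
    torch.autograd.set_detect_anomaly(True)

    # ... run your usual training step here; once the loss turns NaN, the
    # raised error's traceback points at the forward op whose gradient broke.

If the NaN originates in the newly added attention layers, lowering the learning rate for them or clipping gradients with torch.nn.utils.clip_grad_norm_ are common remedies.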