I encountered `Loss is nan, stopping training` after adding an extra `nn.MultiheadAttention` module to my model. I have verified that when the `if` block below is skipped, training runs without error. Can anyone spot what's causing the NaN in this part of the code?
```python
def forward(self, x):
    x = self.forward_features(x)  # [32, 198, 384]
    if self.add_attn:  # This if block is causing the `nan`
        B, N, C = x.shape  # batch_size, sequence_length, embed_dim
        x_qkv = self.add_forward(x)  # nn.Linear(self.embed_dim, 3 * self.embed_dim)
        x_qkv = x_qkv.reshape(B, N, 3, C).permute(2, 1, 0, 3)  # (3, N, B, C)
        # self.add_attn is nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12)
        x, _ = self.add_attn(x_qkv[0], x_qkv[1], x_qkv[2])  # (N, B, C), seq-first
        x = x.permute(1, 0, 2)  # back to (B, N, C)
    x = self.GeM(x, axis=1)
    x = self.head(x)
    return x
```
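For anyone trying to reproduce, here is a minimal standalone sketch of just the attention branch, with the shapes from the comments filled in. The module names (`add_forward`, `add_attn`) and dimensions are taken from the snippet above; everything else (random input, default initialization) is my assumption, so this only checks the shape handling, not the training dynamics:

```python
import torch
import torch.nn as nn

# Hypothetical shapes matching the comments: B=32, N=198, C=384, 12 heads
B, N, C = 32, 198, 384
x = torch.randn(B, N, C)

add_forward = nn.Linear(C, 3 * C)                       # projects to q, k, v
add_attn = nn.MultiheadAttention(embed_dim=C, num_heads=12)

x_qkv = add_forward(x)                                  # (B, N, 3C)
x_qkv = x_qkv.reshape(B, N, 3, C).permute(2, 1, 0, 3)   # (3, N, B, C)
# Default MultiheadAttention expects (seq_len, batch, embed_dim)
out, _ = add_attn(x_qkv[0], x_qkv[1], x_qkv[2])         # (N, B, C)
out = out.permute(1, 0, 2)                              # back to (B, N, C)
print(out.shape)
```

A single forward pass like this stays finite on random input, which suggests the shapes themselves are consistent and the NaN more likely appears during training (e.g. exploding gradients through the new layers); `torch.autograd.set_detect_anomaly(True)` can help pinpoint the first operation that produces it.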