I encountered "Loss is nan, stopping training" when training my model after adding an extra `nn.MultiheadAttention` module. I have checked that training passes without error when the `if` block below is skipped, so the problem must be inside it. Can anyone spot what's causing the `nan` in this part of the code?
```python
def forward(self, x):
    x = self.forward_features(x)  # [32, 198, 384]
    if self.add_attn:  # this `if` block is causing the `nan`
        B, N, C = x.shape  # batch_size, sequence_length, embed_dim
        x_qkv = self.add_forward(x)  # nn.Linear(self.embed_dim, 3 * self.embed_dim)
        x_qkv = x_qkv.reshape(B, N, 3, C).permute(2, 1, 0, 3)  # (3, N, B, C)
        # self.add_attn is nn.MultiheadAttention(embed_dim=embed_dim, num_heads=12)
        x, _ = self.add_attn(x_qkv[0], x_qkv[1], x_qkv[2])
        x = x.permute(1, 0, 2)  # back to (B, N, C)
    x = self.GeM(x, axis=1)
    x = self.head(x)
    return x
```
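For reference, here is a minimal standalone sketch of the same qkv projection + attention path, isolated from the rest of the model (the module names, `embed_dim=384`, and `num_heads=12` are taken from the snippet above; everything else is an assumption). On random input with freshly initialized weights it does not produce `nan`:

```python
import torch
import torch.nn as nn

class AttnBlock(nn.Module):
    """Standalone sketch of the qkv projection + attention path above.
    Names and dims mirror the question; this is not the actual model."""
    def __init__(self, embed_dim=384, num_heads=12):
        super().__init__()
        self.add_forward = nn.Linear(embed_dim, 3 * embed_dim)
        self.add_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

    def forward(self, x):
        B, N, C = x.shape
        x_qkv = self.add_forward(x)                            # (B, N, 3C)
        x_qkv = x_qkv.reshape(B, N, 3, C).permute(2, 1, 0, 3)  # (3, N, B, C)
        q, k, v = x_qkv[0], x_qkv[1], x_qkv[2]                 # each (N, B, C), seq-first
        out, _ = self.add_attn(q, k, v)
        return out.permute(1, 0, 2)                            # back to (B, N, C)

torch.manual_seed(0)
block = AttnBlock()
y = block(torch.randn(2, 8, 384))
print(torch.isnan(y).any().item())
```

If this isolated block is clean in your setup too, wrapping a training step in `torch.autograd.set_detect_anomaly(True)` can pinpoint which operation first produces the `nan` during backward.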