Getting NaN values in backward pass (triplet loss function)

Hi there, I’m trying to reimplement SCAN.

My code is here.

The loss doesn’t change between iterations. When I call backward to calculate the gradients, they turn out to be NaN.

Do you guys have any suggestions?

Hi,

I’d try anomaly detection (see the torch.autograd page, "Automatic differentiation package", in the PyTorch 1.8.0 documentation) to find which part of the backward pass emits nan first.
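A minimal sketch of what that suggestion could look like, assuming the documentation link refers to `torch.autograd.detect_anomaly`. The helper `backward_anomaly_message` and the two toy losses are hypothetical names used only for illustration; they are not taken from the code in this thread.

```python
import torch


def backward_anomaly_message(make_loss):
    """Build a loss, run backward under anomaly detection, and return the
    error message if a backward function produced nan (None otherwise)."""
    try:
        # The forward pass must also run inside the context so that the
        # error can point back to the forward operation that is to blame.
        with torch.autograd.detect_anomaly():
            loss = make_loss()
            loss.backward()
    except RuntimeError as err:
        return str(err)
    return None


def bad_loss():
    # Hypothetical toy loss: dividing by a sum that is exactly zero.
    x = torch.zeros(3, requires_grad=True)
    return (x / x.sum()).sum()


def good_loss():
    # Same computation with a non-zero denominator.
    x = torch.ones(3, requires_grad=True)
    return (x / x.sum()).sum()


print(backward_anomaly_message(bad_loss))   # names the offending backward function
print(backward_anomaly_message(good_loss))  # None when no nan is produced
```

Anomaly detection slows training down noticeably, so it is usually switched on only while debugging.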

It showed: Function 'DivBackward0' returned nan values in its 0th output.

I’m trying to figure it out but have found nothing :<

@ptrblck Sir please give me a hand!

Oh, it’s a little bit hard to identify which layer is responsible.

nan can occur for several reasons, but most often it comes from maths involving 0 or inf, such as 0/0 or inf - inf.
For example, in the SCAN code (SCAN/model.py at master · kuanghuei/SCAN · GitHub), nan and inf can appear in the forward pass of l1norm and l2norm. So I think it’s better to investigate where those bad values are generated, for example by using forward/backward hooks.
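To make the 0/inf point concrete, here is a small sketch of an L2 normalisation that breaks on an all-zero vector and a variant that stays finite. Both `l2norm_naive` and `l2norm_safe` are hypothetical helpers written for this illustration; they are not the SCAN implementation, and the `eps` value is an assumption.

```python
import torch


def l2norm_naive(x, dim=-1):
    # For an all-zero vector the norm is 0, so this computes 0 / 0 = nan.
    norm = x.pow(2).sum(dim=dim, keepdim=True).sqrt()
    return x / norm


def l2norm_safe(x, dim=-1, eps=1e-8):
    # Adding eps inside the sqrt keeps both the division and the
    # derivative of sqrt finite when the vector is all zeros.
    norm = (x.pow(2).sum(dim=dim, keepdim=True) + eps).sqrt()
    return x / norm


x = torch.zeros(2, 4, requires_grad=True)

naive_out = l2norm_naive(x)      # every entry is nan
safe_out = l2norm_safe(x)        # every entry is 0
safe_out.sum().backward()        # gradients stay finite

print(torch.isnan(naive_out).any().item())
print(torch.isfinite(x.grad).all().item())
```

All-zero rows are worth looking for in practice, since padded positions or dead activations can produce them.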

Here is one way to implement a backward hook that reports the module name.
A forward hook with similar functionality would also be helpful.

import torch

# backward hook with module name
def get_backward_hook(module_name: str):

    class BackwardHook:
        def __init__(self, name: str):
            self.name = name

        def __call__(self, module, grad_input, grad_output):
            # Entries can be None for inputs that do not require gradients.
            for i, g_in in enumerate(grad_input):
                if g_in is not None and torch.isnan(g_in).any():
                    print(f"{self.name}: input gradient {i} contains nan")
            for i, g_out in enumerate(grad_output):
                if g_out is not None and torch.isnan(g_out).any():
                    print(f"{self.name}: output gradient {i} contains nan")

    return BackwardHook(module_name)

...

model = SomeCoolModel(...)
for name, module in model.named_modules():
    module.register_full_backward_hook(get_backward_hook(name))
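
The forward-hook counterpart mentioned above could be sketched roughly as follows. The names `get_forward_hook`, `DivideByZero` and the `found` list are hypothetical and exist only for this illustration; the toy model stands in for the real one.

```python
import torch
import torch.nn as nn


def get_forward_hook(module_name, found):
    """Return a forward hook that records (module_name, output_index)
    whenever a module output contains nan or inf."""
    def hook(module, inputs, output):
        outputs = output if isinstance(output, (tuple, list)) else (output,)
        for i, out in enumerate(outputs):
            if torch.is_tensor(out) and not torch.isfinite(out).all():
                found.append((module_name, i))
    return hook


class DivideByZero(nn.Module):
    # Hypothetical layer that deliberately produces inf/nan outputs.
    def forward(self, x):
        return x / torch.zeros_like(x)


model = nn.Sequential(nn.Linear(4, 4), DivideByZero(), nn.ReLU())

found = []
for name, module in model.named_modules():
    module.register_forward_hook(get_forward_hook(name, found))

model(torch.ones(2, 4))
print(found[0])  # the first module whose output was not finite
```

Because hooks fire in execution order, the first recorded entry points at the layer where the bad values first appear in the forward pass.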

Hope this helps.

Thank you for your generous response. I found that something is wrong in this piece of code, but I can’t figure out what. Could you please have a look at it?

def forward(self, images: Tensor, captions: Tensor, cap_lens: Tensor) -> Tensor:
        """
        Parameters
        ----------
        images:      (batch, regions, hidden_dim)
        captions:    (batch, cap_lengths, hidden_dim)
        cap_lens:    (batch)
        Returns
        -------
        sim_scores: (batch cap, batch image)
        """
        # Projecting the image/caption to embedding dimensions
        img_embs = self.img_enc(images)
        cap_embs = self.txt_enc(captions, cap_lens)
        # -> img_embs: (batch, regions, hidden)
        # -> cap_embs: (batch, max_seq_len, hidden)

        # Compute similarity between each region and each token
        # cos(u, v) = (u @ v) / (||u|| * ||v||)
        #           = (u / ||u||) @ (v / ||v||)
        img_embs = F.normalize(img_embs, dim=-1)
        cap_embs = F.normalize(cap_embs, dim=-1)

        img_embs.unsqueeze_(0)
        cap_embs.unsqueeze_(1)
        # -> img_embs: (1, batch, regions, hidden)
        # -> cap_embs: (batch, 1, max_seq_len, hidden)

        # After normalizing: cos(u, v) = u @ v
        sim_token_region = torch.matmul(cap_embs, img_embs.transpose(-1, -2))
        # -> sim_token_region: (batch, batch, max_seq_len, regions)

        att_score = self.lambda_softmax * sim_token_region
        # -> att_score: (batch, batch, max_seq_len, regions)

        # Create mask from caption length
        # torch.arange(max_seq_len_SIZE).expand(batch_SIZE, max_seq_len_SIZE)
        padding_mask = torch.arange(cap_embs.size(2)).expand(
            cap_embs.size(0), cap_embs.size(2)
        ) >= cap_lens.unsqueeze(1)
        # padding_mask: (batch, max_seq_len)
        
        padding_mask = padding_mask.unsqueeze(-1).unsqueeze(1).to(cap_embs.device)
        # -> padding_mask: (batch, 1, max_seq_len, 1)

        # mask score of padding tokens
        att_score.data.masked_fill_(padding_mask, -float('inf'))

        # softmax along regions axis
        attention_weights = F.softmax(att_score, dim=-1)
        # -> attention_weights: (batch, batch, max_seq_len, regions)

        # Calculate weighted sum of regions -> attended image vectors
        attention = torch.matmul(attention_weights, img_embs)
        # -> attention: (batch, batch, max_seq_len, hidden_dim)

        # Calculate the importance of each attended image vector
        # w.r.t each token of sentence
        r = F.cosine_similarity(cap_embs, attention, dim=-1)
        # -> r: (batch, batch, max_seq_len)

        # Zeros similarity of padding tokens
        r[r.isnan()] = 0

        # Calculate similarity of each caption and each image by averaging all tokens in a caption.
        sim_scores = r.sum(dim=-1) / cap_lens.view(-1, 1).to(cap_embs.device)
        # -> sim_score: (batch, batch)

        return sim_scores

Did you find that in the forward computation or in the backward pass?

Either way, I’m not sure, but I’d check the values of img_embs and the values going into the backward pass.
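
One thing that may be worth checking in the forward above (a guess, not something confirmed in this thread): the padding mask fills every region score of a padded token with -inf, so the softmax over regions for that token is nan. Overwriting the nan entries afterwards cleans up the forward values, but in the backward pass a zero gradient multiplied by a nan weight is still nan. The sketch below reproduces this effect on toy tensors; `attend` and `values_grad` are hypothetical helpers, and the shapes are far smaller than in the real model.

```python
import torch
import torch.nn.functional as F


def attend(scores, values, pad, mask_before_softmax):
    """scores: (tokens, regions), values: (regions, hidden), pad: (tokens,) bool."""
    if mask_before_softmax:
        # A padded token gets -inf for every region, so its softmax row is nan.
        scores = scores.masked_fill(pad.unsqueeze(-1), float("-inf"))
    weights = F.softmax(scores, dim=-1)
    attended = (weights.unsqueeze(-1) * values).sum(dim=1)  # (tokens, hidden)
    sim = attended.sum(dim=-1)                              # (tokens,)
    # Zero out padded tokens so the forward result never shows a nan.
    return torch.where(pad, torch.zeros_like(sim), sim)


def values_grad(mask_before_softmax):
    torch.manual_seed(0)
    scores = torch.randn(2, 3)
    values = torch.randn(3, 4, requires_grad=True)
    pad = torch.tensor([False, True])  # the second token is padding
    sim = attend(scores, values, pad, mask_before_softmax)
    sim.sum().backward()
    return sim.detach(), values.grad


sim_bad, grad_bad = values_grad(mask_before_softmax=True)
sim_ok, grad_ok = values_grad(mask_before_softmax=False)

print(torch.isfinite(sim_bad).all().item(), torch.isnan(grad_bad).any().item())
print(torch.isfinite(grad_ok).all().item())
```

If this turns out to be the cause, one option is to skip the -inf fill and zero out padded tokens only after the similarities have been computed, as the second call does.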