Grads are None after backward (broken graph?)

Hi all,

I am adapting an existing implementation to my goal by changing the loss calculation, but with the new loss every grad is None after the backward step.

Based on similar topics I have read, I believe the graph might be broken, but I am also unsure whether it has to do with the loss itself and the way the backward pass is done.

The model output is the following:

It also returns several losses in a dictionary (dict_loss).

However, I want to use CrossEntropyLoss on each bounding box. For example, for an image with 15 bounding boxes, each with one label, the target in the loss calculation is a tensor of 15 labels.
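For reference, a minimal sketch of a per-box CrossEntropyLoss (the sizes are hypothetical: 15 boxes, 5 classes). `nn.CrossEntropyLoss` expects an `(N, C)` tensor of unnormalized scores and an `(N,)` int64 tensor of class indices in `[0, C)`. The random `box_logits` leaf only stands in for the per-box scores; in the real pipeline they must be computed from the model outputs so that they stay in the graph.

```python
import torch
import torch.nn as nn

n_boxes, n_classes = 15, 5  # hypothetical: 15 boxes, 5 classes
# Stand-in for per-box class scores; in the real pipeline this (N, C) tensor
# must be produced from the model outputs with torch ops so it stays in the graph.
box_logits = torch.randn(n_boxes, n_classes, requires_grad=True)
target = torch.randint(0, n_classes, (n_boxes,))  # one int64 class index per box

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(box_logits, target)  # input (N, C), target (N,) -> scalar mean loss
loss.backward()
print(loss.item(), box_logits.grad.shape)
```

If the target labels start at 1 while the score columns start at the first real class, the targets would need shifting to 0-based indices first.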

To get a tensor with the class probabilities for each bounding box from the model outputs, I wrote two functions. The first takes the pred_boxes, pred_classes and scores and builds pred_masks.
The pred_masks have shape (batch_size, n_classes + 1, height, width), where channel 0 is left empty.

    def get_pred_mask(self, outputs, n_classes, mask_shape):
        final_masks = []

        for ioutputs in outputs:

            pred_mask_cl = [np.zeros((mask_shape[0],  mask_shape[1])) for icat in range(0,n_classes+1)]
            
            for icl in range(0, len(ioutputs['instances'].pred_classes)):               
                for icat in range(0,n_classes):
                    if ioutputs['instances'].pred_classes[icl]==icat:
                        coord = ioutputs['instances'].pred_boxes[icl].tensor
                        class_coord = (int(coord[0][0]), int(coord[0][1])), (int(coord[0][0]+coord[0][2]), int(coord[0][1]+coord[0][3]))
                        color = int(ioutputs['instances'].pred_classes[icl])
                        pred_mask_cl[icat+1] = cv2.rectangle(pred_mask_cl[icat+1], class_coord[0], class_coord[1], color, -1)

            final_masks.append(pred_mask_cl)

        return torch.tensor(np.stack(final_masks), requires_grad=True, device='cuda')

Then I take this output and use it to get the class probabilities for each bounding box, taking the most common class inside each bounding box to consider.


    def get_bbox_label(self, y_t, bbox_to_consider, n_classes, first_class=1, original_shape=(1025,1025)): #batch[0]['image'].shape[1:]):
        softmax=nn.Softmax(dim=1)
        y_score=y_t 
        y_=y_t.argmax(1)
        final_labels = []
        scale = y_[0].shape[0]/original_shape[0]
        for i in range(0,len(y_)):
            img_id=str(uuid.uuid4())
            bbox_group_all=[]
            bbox_group_score_all=[]
            for ixb in bbox_to_consider[i]:
                class_coord = ((int(ixb[0]*scale), int(ixb[1]*scale)), (int(ixb[2]*scale), int(ixb[3]*scale)))
                bbox_group = y_[i][int(ixb[1]*scale):int(ixb[3]*scale), int(ixb[0]*scale):int(ixb[2]*scale)]  
                bbox_group_score = y_score[i][first_class:,int(ixb[1]*scale):int(ixb[3]*scale), int(ixb[0]*scale):int(ixb[2]*scale)]
                bbox_group = Counter(bbox_group.reshape(bbox_group.shape[0]*bbox_group.shape[1]).tolist()).most_common(1)[0][0]
                bbox_group_score = bbox_group_score.reshape(bbox_group_score.shape[0], bbox_group_score.shape[1]*bbox_group_score.shape[2])
                bbox_group_all.append(bbox_group if bbox_group!=0 else 1) #force that if no label it is text
                bbox_group_score_all.append([np.round(i.item(),2) for i in bbox_group_score.max(1).values.reshape(1, n_classes)[0]])
                
            final_labels.append(pd.DataFrame({'bbox_label': [bbox_group_all], 'bbox_group_score': [bbox_group_score_all], 'img': img_id}))

        final_labels = pd.concat(final_labels).reset_index(drop=True)
        pred_label_bbox = torch.tensor(np.concatenate([item for sublist in final_labels.bbox_group_score.tolist() for item in sublist]).reshape(50,5), requires_grad=True, device='cuda')

        
        return(pred_label_bbox)
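
In `get_bbox_label` the per-box scores pass through `.item()`, `np.round`, a pandas DataFrame and `torch.tensor(...)`, each of which leaves autograd (the `Counter`/argmax "most common class" is not differentiable either, but it is not needed for the scores). As a hedged sketch, assuming the class maps are a `(C, H, W)` tensor still attached to the graph and the boxes are XYXY and non-empty after scaling, the same max-per-class pooling can be written with slicing, `amax` and `torch.stack` only. `box_class_scores` is a hypothetical helper name.

```python
import torch


def box_class_scores(class_maps, boxes, scale=1.0):
    """Hypothetical helper: per-class max score inside each box.

    class_maps: (C, H, W) float tensor that is still part of the graph.
    boxes: (N, 4) XYXY boxes in original-image coordinates (assumed non-empty
    after scaling). Returns an (N, C) tensor built only from slicing, amax and
    stack, so gradients flow back into class_maps.
    """
    rows = []
    for x0, y0, x1, y1 in (boxes * scale).long().tolist():
        region = class_maps[:, y0:y1, x0:x1]        # (C, h, w) slice, keeps grad_fn
        rows.append(region.flatten(1).amax(dim=1))  # (C,) max per class
    return torch.stack(rows)                        # (N, C)


class_maps = torch.rand(5, 64, 64, requires_grad=True)  # stand-in for (C, H, W) score maps
boxes = torch.tensor([[0.0, 0.0, 32.0, 32.0], [16.0, 16.0, 64.0, 48.0]])
scores = box_class_scores(class_maps, boxes)
scores.sum().backward()
print(scores.shape, class_maps.grad is not None)
```

The box coordinates are converted to Python ints on purpose: they only select the region, and gradients are wanted for the scores, not the coordinates.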

Finally, I calculate the loss (both tensors used in this calculation are leaf tensors):

    c_losses = loss_fn(pred_label_bbox.reshape(len(target), n_classes), target)

The logic that follows is:

    grad_scaler = GradScaler()
    (...)
    grad_scaler.scale(c_losses).backward()
    grad_scaler.step(optimizer)
    grad_scaler.update()

Where I get:

    File /usr/local/lib/python3.8/dist-packages/torch/cuda/amp/grad_scaler.py:343, in GradScaler.step(self, optimizer, *args, **kwargs)
        341 print("debug...")
        342 print(optimizer_state["found_inf_per_device"])
    --> 343 assert len(optimizer_state["found_inf_per_device"]) > 0, "No inf checks were recorded for this optimizer."
        345 retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
        347 optimizer_state["stage"] = OptState.STEPPED

    AssertionError: No inf checks were recorded for this optimizer.

PROBLEM: From here I found out that the grads are None!
However, if I use the original implementation loss, everything works well.

The differences I see:

1. Original implementation:

        sum(dict_loss.values())
        > tensor(5.3273, device='cuda:0', grad_fn=<AddBackward0>)

        grad_scaler.scale(sum(dict_loss.values()))
        > tensor([349132.5312], device='cuda:0', grad_fn=<MulBackward0>)

2. Mine:

        loss
        > tensor(1.7942, device='cuda:0', dtype=torch.float64, grad_fn=<NllLossBackward0>)

        grad_scaler.scale(loss)
        > tensor([117585.7969], device='cuda:0', grad_fn=<MulBackward0>)
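
Both losses print a `grad_fn`, but that only shows the loss depends on *some* tensor with `requires_grad=True`; with the custom loss that tensor can be the new leaf created by `torch.tensor(..., requires_grad=True)` rather than the model parameters. A hedged diagnostic sketch: walk `loss.grad_fn.next_functions`, collect the leaves reached through `AccumulateGrad` nodes (their `.variable` attribute), and intersect them with the model parameters. `reachable_leaves` is a hypothetical helper and `nn.Linear` stands in for the real model.

```python
import torch
import torch.nn as nn


def reachable_leaves(loss):
    """Hypothetical helper: ids of leaf tensors that backward from `loss` reaches."""
    leaves, seen, stack = set(), set(), [loss.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, "variable"):  # AccumulateGrad node -> a leaf tensor
            leaves.add(id(fn.variable))
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return leaves


torch.manual_seed(0)
model = nn.Linear(4, 3)  # stand-in for the real model
x = torch.randn(2, 4)
target = torch.tensor([0, 2])

# Attached: loss computed directly from the model output.
good = nn.functional.cross_entropy(model(x), target)

# Detached: output re-wrapped into a fresh leaf, as in the question.
rewrapped = torch.tensor(model(x).detach().cpu().numpy(), requires_grad=True)
bad = nn.functional.cross_entropy(rewrapped, target)

param_ids = {id(p) for p in model.parameters()}
print(bad.grad_fn is not None)                   # True: a grad_fn exists...
print(bool(param_ids & reachable_leaves(good)))  # True
print(bool(param_ids & reachable_leaves(bad)))   # False: ...but no parameter is reached
```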

Thank you in advance for your help. Any insight will be much appreciated!

You are detaching the computation graph by rewrapping the tensors as seen e.g. here:

    return torch.tensor(np.stack(final_masks), requires_grad=True, device='cuda')

    pred_label_bbox = torch.tensor(np.concatenate([item for sublist in final_labels.bbox_group_score.tolist() for item in sublist]).reshape(50,5), requires_grad=True, device='cuda')

These operations create a new leaf tensor without any Autograd history, which also means that all previous operations used to create the wrapped data are detached.

Try to reuse the tensor directly and stick to PyTorch operations to avoid detaching the computation graph.
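A small sketch of the difference, with `nn.Linear` standing in for the detector (an assumption for illustration). The first path mirrors the pattern quoted above: values go through `.item()`/NumPy and are re-wrapped with `torch.tensor(..., requires_grad=True)`, so backward stops at the new leaf and `model.weight.grad` stays `None`. It also shows that the re-wrapped tensor is float64, because NumPy stores Python floats as float64, which is probably why the loss in the question prints `dtype=torch.float64`. The second path keeps the rows as tensors and combines them with `torch.stack`, so gradients reach the model.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 5)  # stand-in for the detector
x = torch.randn(3, 4)
target = torch.tensor([1, 0, 4])

# Detached path: values leave autograd via .item()/NumPy and are re-wrapped.
rows = [[round(v.item(), 2) for v in row] for row in model(x)]
leaf = torch.tensor(np.array(rows), requires_grad=True)
nn.functional.cross_entropy(leaf, target).backward()
detached_grad = model.weight.grad  # None: backward stops at `leaf`
print(leaf.dtype)                  # torch.float64 (NumPy default for Python floats)

# Attached path: keep the rows as tensors and combine them with torch ops.
scores = torch.stack([row for row in model(x)])  # (3, 5) with a grad_fn
nn.functional.cross_entropy(scores, target).backward()
attached_grad = model.weight.grad  # populated now
print(detached_grad is None, attached_grad is not None)  # True True
```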


PROBLEM: Even though some grads are no longer None, this is not the case for all grads.
So I no longer get the error (since some grads now have values), but training still has a problem: after the first backward step, some grads seem to be set to None, and every prediction is the same.

Do you have any idea what the problem might be? Could it be because I am not using all the model outputs to calculate the loss? Does this make sense?

I am now using this to transform the model output:

    def get_bbox_pred_mask(self, batch, outputs, n_classes):
        img_shape = batch[0]['image'].shape[1:]
        mask_shape = batch[0]['mask'].shape[1:]
        scale = mask_shape[0]/img_shape[0]
        final_masks = []
        final_bbox_pred_proba = []

        for ix in range(0, len(outputs)):
            ioutputs=outputs[ix]

            pred_mask_cl = [torch.zeros((mask_shape[0],  mask_shape[1]), requires_grad=True, device='cuda') for icat in range(0,n_classes+1)]
            pred_icl = ioutputs['instances'].pred_classes
            for iclix in range(0,len(pred_icl)):
                cl = pred_icl[iclix]
                pred_mask_cl[cl+1] = pred_mask_cl[cl+1] + ioutputs['instances'].pred_masks[iclix]*ioutputs['instances'].scores[iclix]
            final_masks.append(pred_mask_cl) 

            bbox_pred_proba = []
            for ixbix in range(0, len(batch[ix]['instances'].gt_boxes.tensor)):
                ixb = batch[ix]['instances'].gt_boxes.tensor[ixbix]
                gt_cl = batch[ix]['instances'].gt_classes[ixbix]
                bbox_pred_proba.append(torch.stack([imask[int(ixb[1]*scale):int(ixb[3]*scale), int(ixb[0]*scale):int(ixb[2]*scale)].mean() for imask in pred_mask_cl][1:]))
            final_bbox_pred_proba.append(torch.stack(bbox_pred_proba))
        
        return(final_bbox_pred_proba, final_masks)
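
As a hedged sketch, the per-class accumulation loop above can also be written without Python-level indexing into a list, assuming `pred_masks` is `(K, H, W)`, `pred_classes` is an int64 `(K,)` tensor and `scores` is `(K,)`. `Tensor.index_add` sums each score-weighted mask into channel `class + 1` and is differentiable with respect to the added values, so the zero initial masks do not need `requires_grad=True`. One caveat worth checking: if `pred_masks` are thresholded (boolean) masks, gradient can only reach the model through `scores`, not through the mask values. `class_score_maps` is a hypothetical name.

```python
import torch


def class_score_maps(pred_masks, pred_classes, scores, n_classes):
    """Hypothetical vectorized version of the per-class accumulation loop.

    pred_masks: (K, H, W) masks, pred_classes: (K,) int64, scores: (K,).
    Returns (n_classes + 1, H, W); channel 0 stays zero, matching the
    `cl + 1` offset used in the loop.
    """
    k, h, w = pred_masks.shape
    weighted = pred_masks * scores.view(k, 1, 1)         # score-weighted masks (float)
    out = weighted.new_zeros(n_classes + 1, h, w)        # no requires_grad needed here
    return out.index_add(0, pred_classes + 1, weighted)  # sum per class, differentiable


masks = torch.rand(3, 4, 4, requires_grad=True)
classes = torch.tensor([0, 2, 0])
scores = torch.tensor([0.9, 0.5, 0.3])
maps = class_score_maps(masks, classes, scores, n_classes=3)
maps.sum().backward()
```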

If some parameters are still getting a None gradient (i.e. no gradient is ever set) it would mean that these parameters are either not used to compute the loss or the computation graph including these parameters is still being detached. Check how and where these parameters are supposedly used and check how they influence the loss calculation.
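A hedged sketch for finding those parameters: after `backward()`, list the trainable parameters whose `.grad` is still `None`. The toy `TwoHeads` model is hypothetical and has a head that the loss never uses, so its parameters show up. Note also that in PyTorch 2.0+ `optimizer.zero_grad()` defaults to `set_to_none=True`, so a parameter the loss never reaches reads `None` again after every step, and optimizers such as `SGD` skip parameters whose grad is `None`.

```python
import torch
import torch.nn as nn


def params_without_grad(model):
    """Names of trainable parameters that received no gradient in the last backward."""
    return [name for name, p in model.named_parameters()
            if p.requires_grad and p.grad is None]


class TwoHeads(nn.Module):
    """Toy model (hypothetical): the loss below only uses cls_head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.cls_head = nn.Linear(8, 5)
        self.box_head = nn.Linear(8, 4)

    def forward(self, x):
        feats = torch.relu(self.backbone(x))
        return self.cls_head(feats), self.box_head(feats)


model = TwoHeads()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad(set_to_none=True)  # grads start as None
logits, boxes = model(torch.randn(6, 8))
loss = nn.functional.cross_entropy(logits, torch.randint(0, 5, (6,)))
loss.backward()
print(params_without_grad(model))  # ['box_head.weight', 'box_head.bias']
optimizer.step()                   # SGD skips params whose grad is None
```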
