Gradients for different parts of the loss are None even though `.retain_grad()` is called

Hi folks,

I’m training an autoencoder on 3D objects, and this is the piece of code that computes the loss in the forward pass:

    def forward(self, in_pc_batch, iteration, frame, t_idx, t_nor, faces, bDebug = False):

        nbat = in_pc_batch.size(0)
        npt = in_pc_batch.size(1)
        nch = in_pc_batch.size(2)


        t_mu, t_logstd = self.net_geoenc(in_pc_batch, self.mcvcoeffsenc) # input in mm, output in dm
        t_std = t_logstd.exp()

        # reparameterization trick: z = mu + std * eps, eps ~ N(0, 1)
        t_eps = torch.ones_like(t_std).normal_()
        t_z = t_mu + t_std * t_eps

        # closed-form KL divergence to a standard normal prior
        klloss = torch.mean(-0.5 - t_logstd + 0.5 * t_mu ** 2 + 0.5 * torch.exp(2 * t_logstd))
        out_pc_batchfull = self.net_geodec(t_z, self.mcvcoeffsdec)
        out_pc_batch = out_pc_batchfull[:,:,0:3]


        dif_pos = out_pc_batch - in_pc_batch

        faces_long = faces.long()  # indices must be integer typed for gathering
        vet0 = index_selection_nd(dif_pos, faces_long[:, :, 0], 1)
        vet1 = index_selection_nd(dif_pos, faces_long[:, :, 1], 1)
        vet2 = index_selection_nd(dif_pos, faces_long[:, :, 2], 1)

        loss_normal = ((vet1 + vet0 + vet2)/3.0 * t_nor).sum(2).pow(2).mean()

        loss_pose_l1 = self.net_loss.compute_geometric_loss_l1(in_pc_batch[:,:,0:3], out_pc_batch[:,:,0:3])
        loss_laplace_l1 = self.net_loss.compute_laplace_loss_l2(in_pc_batch[:,:,0:3], out_pc_batch[:,:,0:3])

        loss = loss_pose_l1*self.w_pose +  loss_laplace_l1 * self.w_laplace  + klloss * self.klweight + loss_normal * self.w_nor

But when I try to inspect the gradients of the individual losses, they are None, even though I explicitly call `.retain_grad()` on them:

        loss, loss_pose_l1, loss_laplace_l1, klloss, loss_normal = net_autoenc(pc_batch, iteration, frame, t_idx, t_nor, faces, True)

        optimizer.zero_grad()
        for l in [loss_laplace_l1, loss_pose_l1, klloss, loss_normal]:
            l.retain_grad()

        loss.register_hook(lambda grad: print("Loss_grad: ", grad))
        loss_pose_l1.register_hook(lambda grad: print("Loss_pos_l1_grad: ", grad))

        loss.backward(torch.ones(loss.size(0), device=device))

        print("Loss_pose_l1_grad: ", loss_pose_l1.grad)
        print("Loss_laplace_l1_grad: ", loss_laplace_l1.grad)
        print("klloss_grad: ", klloss.grad)
        print("Loss_normal_grad: ", loss_normal.grad)
        optimizer.step()

The output of this is:

Loss_grad:  tensor([1.], device='cuda:0')
Loss_pose_l1_grad:  None
Loss_laplace_l1_grad:  None
klloss_grad:  None
Loss_normal_grad:  None

When I inspect the tensors interactively:

ipdb> loss_pose_l1
tensor([0.0060], device='cuda:0', grad_fn=<UnsqueezeBackward0>)
ipdb> loss_pose_l1.grad
ipdb> loss_pose_l1.requires_grad
True
ipdb> loss_pose_l1.is_leaf
False
ipdb> loss_pose_l1.retains_grad
True
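
For comparison, in a toy setup `.retain_grad()` on a non-leaf tensor does populate `.grad` after `backward()` (a minimal sketch, independent of my actual model):

    import torch

    x = torch.randn(3, requires_grad=True)    # leaf tensor
    y = (x * 2).sum()                         # non-leaf, y.is_leaf == False
    y.retain_grad()                           # keep y.grad after backward
    (y * 5).backward()
    print(y.is_leaf, y.retains_grad, y.grad)  # False True tensor(5.)

That is the behaviour I expected for my loss terms as well.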

Could someone please help me understand why the gradients are None?

My PyTorch version is '1.9.0+cu111'.

Thanks in advance!

Assuming you are returning the summed loss as well as all separate losses in your forward method, it seems to work for me:

def fun(model, data):
    out = model(data)
    loss_pose_l1 = out.mean()
    loss_laplace_l1 = out.sum()
    klloss = F.mse_loss(out, torch.rand_like(out))
    loss_normal = (out**2).mean()    
    
    loss = loss_pose_l1 + loss_laplace_l1 + klloss + loss_normal
    return loss, loss_pose_l1, loss_laplace_l1, klloss, loss_normal


model = nn.Linear(1, 1)
data = torch.randn(1, 1)
loss, loss_pose_l1, loss_laplace_l1, klloss, loss_normal = fun(model, data)

for l in [loss_laplace_l1, loss_pose_l1, klloss, loss_normal]:
    l.retain_grad()
    
loss.backward()
print("Loss_pose_l1_grad: ", loss_pose_l1.grad)
print("Loss_laplace_l1_grad: ", loss_laplace_l1.grad)
print("klloss_grad: ", klloss.grad)
print("Loss_normal_grad: ", loss_normal.grad)

So could you post an executable code snippet that reproduces the issue?

Hi ptrblck,
thanks a lot for the reply!

Yes, I had already tested a simple block like that and it works, so I was quite confused about what is going wrong in my case.

The only thing that finally worked for me was to put `.register_hook()` inside the forward method:

    def forward(self, in_pc_batch, iteration, frame, t_idx, t_nor, faces, bDebug = False):

        nbat = in_pc_batch.size(0)
        npt = in_pc_batch.size(1)
        nch = in_pc_batch.size(2)


        t_mu, t_logstd = self.net_geoenc(in_pc_batch, self.mcvcoeffsenc) # input in mm, output in dm
        t_std = t_logstd.exp()

        # reparameterization trick: z = mu + std * eps, eps ~ N(0, 1)
        t_eps = torch.ones_like(t_std).normal_()
        t_z = t_mu + t_std * t_eps

        t_z.register_hook(lambda grad: print("t_z_grad: ", grad, ", shape: ", grad.shape))

        klloss = torch.mean(-0.5 - t_logstd + 0.5 * t_mu ** 2 + 0.5 * torch.exp(2 * t_logstd))
        out_pc_batchfull = self.net_geodec(t_z, self.mcvcoeffsdec)
        out_pc_batch = out_pc_batchfull[:,:,0:3]
        dif_pos = out_pc_batch - in_pc_batch

        faces_long = faces.long()  # indices must be integer typed for gathering
        vet0 = index_selection_nd(dif_pos, faces_long[:, :, 0], 1)
        vet1 = index_selection_nd(dif_pos, faces_long[:, :, 1], 1)
        vet2 = index_selection_nd(dif_pos, faces_long[:, :, 2], 1)
        loss_normal = ((vet1 + vet0 + vet2)/3.0 * t_nor).sum(2).pow(2).mean()

        loss_pose_l1 = self.net_loss.compute_geometric_loss_l1(in_pc_batch[:,:,0:3], out_pc_batch[:,:,0:3])
        loss_laplace_l1 = self.net_loss.compute_laplace_loss_l2(in_pc_batch[:,:,0:3], out_pc_batch[:,:,0:3])

        loss_pose_l1.register_hook(lambda grad: print("Loss_pos_l1_grad: ", grad))
        loss_laplace_l1.register_hook(lambda grad: print("Loss_laplace_l1_grad: ", grad))
        klloss.register_hook(lambda grad: print("KLloss_grad: ", grad))
        loss_normal.register_hook(lambda grad: print("Loss_normal_grad: ", grad))

        loss = loss_pose_l1*self.w_pose + loss_laplace_l1 * self.w_laplace + klloss * self.klweight + loss_normal * self.w_nor

        return loss[None], loss_pose_l1[None], loss_laplace_l1[None], klloss[None], loss_normal[None]

In this case, the hooks print the expected values (the loss weights).
In the opposite case, i.e. when I call `.retain_grad()` or `.register_hook()` outside the forward method, the hook is never called and the gradients stay None.

And to clarify,

        loss, loss_pose_l1, loss_laplace_l1, klloss, loss_normal = net_autoenc(pc_batch, iteration, frame, t_idx, t_nor, faces, True)

net_autoenc is created like this:

net_autoenc = Net_autoenc(param, facedata)
net_autoenc = nn.DataParallel(net_autoenc, device_ids=device_ids).to(device)
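
In case it helps to narrow things down, here is a check I could run to see whether the `DataParallel` wrapper plays a role (a rough sketch, assuming the same input variables are in scope), calling the unwrapped module directly via `.module`:

    # Bypass the DataParallel wrapper and call the underlying module directly
    plain_net = net_autoenc.module

    loss, loss_pose_l1, loss_laplace_l1, klloss, loss_normal = plain_net(
        pc_batch, iteration, frame, t_idx, t_nor, faces, True)

    for l in [loss_laplace_l1, loss_pose_l1, klloss, loss_normal]:
        l.retain_grad()

    loss.sum().backward()  # loss has shape (1,) because of the [None]
    print("Loss_pose_l1_grad: ", loss_pose_l1.grad)
    print("klloss_grad: ", klloss.grad)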

If you need any extra information, please ask!
Cheers