Gradient computation in custom backward

Hi @albanD , thank you for reaching out.

I want to get these:
grad_weight += cont_loss_weight and grad_bias += cont_loss_bias.

For this I have two ideas in my mind, however i don’t know which one is correct (If you can suggest!).
Idea-1:

# keeping remaining code unchanged
if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)  #shape = ([128,96,5,5])
            grad_feat = torch.autograd.grad(cont_loss, feat)         # has to be of shape = ([[batch_size= 100,128, 4, 4])
            cont_loss_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_feat )                                                                     
            grad_weight += cont_loss_weight

# But not sure how to get `cont_loss_bias` in same manner.

Idea-2:

# keeping remaining code unchanged
if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)  #shape = ([128,96,5,5])
            cont_loss_weight = torch.autograd.grad(outputs= cont_loss, inputs= weight, retain_graph=(True))
            grad_weight += cont_loss_weight
            
if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0,2,3))        #shape = ([128])
            cont_loss_bias = torch.autograd.grad(outputs= cont_loss, inputs= bias, retain_graph=(True))
            grad_bias += cont_loss_bias          
         

Here’s the code :

class Custom_Convolution(torch.autograd.Function):    
    
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):  #input(from previous layer)'s shape = ([batch_size=100, 96, 8, 8])
        output = torch.nn.functional.conv2d(input, weight, bias, stride, padding)  
        ctx.save_for_backward(input, weight, bias, output)
        return output    #output's shape = ([[batch_size= 100,128, 4, 4])

    @staticmethod
    def backward(ctx, grad_output):    # grad_output size = ([batch_size, 128,4,4])
        
        input, weight, bias, output = ctx.saved_tensors    #input size = ([batch_size, 96,8,8])
      
        feat = output.clone()#.requires_grad_(True)   # output from forward with size = ([batch_size, 128,4,4])

        feat = feat.view(feat.shape[0], feat.shape[1], -1) # features size = ([batch_size, 128,16])

        i = 0        
        cont = torch.tensor([0.]).to(dev)
        while i in range(0, feat.shape[0]):
              for f in range(len(feat[i])):
                  Zi_unnormalized = feat[i][f]
                  Zi = torch.nn.functional.normalize(Zi_unnormalized, dim = 0)
                  # Zj and Zk are tensors made from feat[i][*] and feat[other than i][*]. Zj and Zk varies for each Zi (or f)

                  Zi_Zk = torch.Tensor([0]).to(dev)
                  for k in Zk:
                      k= torch.nn.functional.normalize(k, dim = 0)
                      zi_zk = ...
                      Zi_Zk = Zi_Zk.add(zi_zk)

                  # Similarly computing Zi_Zj
                  # Li = some algebra of Zi_Zj and Zi_Zk
                  # number of 'Li' values =  feat.shape[0] * feat.shape[1]
                  cont = cont.add(Li)   # 1 value
              i+=1

        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output) #shape = ([batch_size,96,8,8])
        
        # If I go with Idea-2        
        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)  #shape = ([128,96,5,5])
            cont_loss_weight = torch.autograd.grad(outputs= cont_loss, 
                                            inputs= weight, retain_graph=(True))
            grad_weight += cont_loss_weight
            
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0,2,3))        #shape = ([128])
            cont_loss_bias = torch.autograd.grad(outputs= cont_loss, inputs= bias, retain_graph=(True))
           grad_bias += cont_loss_bias          
            
        if bias is not None:
            return grad_input, grad_weight, grad_bias, None, None
        else:
            return grad_input, grad_weight, None, None
 

It gives:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

If I put feat = output.clone().requires_grad_(True), it gives :

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.