Gradient computation in custom backward

Well, that loss is specific to that layer only. I don’t have a concrete argument for that, but I was advised to implement the loss within the layer’s backward only. So, I have to do it that way (no option! :zipper_mouth_face:)

But if that loss term is only influenced by that layer’s weights, the gradient corresponding to that part of the loss will only influence that layer’s weights.

Also, you still haven’t shared the formula of what you’re trying to compute, so it is hard to say :confused:

Yes, you are right. It will only influence that layer’s weights.

Sorry, you can see the loss in equations (2) and (3), and the gradient computation on page 13:
https://arxiv.org/pdf/2004.11362.pdf
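
To save you a click, equation (2) there (the “L_out” form of the supervised contrastive loss), as far as I can transcribe it, is:

L^{sup}_{out} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)}

where z_i is the normalized feature of anchor i, P(i) the positives sharing its label, A(i) the other samples in the batch, and \tau a temperature (please double-check against the paper itself).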

For such a problem, it will definitely be much easier to let autograd figure out the gradients, in particular because you might not have all the information you need during the backward of that layer, as the loss will depend on the layers that come after this one.

Note as well that a good resource is other implementations of contrastive losses for PyTorch: https://github.com/topics/contrastive-loss

From the other implementations I have seen, a separate loss function is created, as you mentioned. But those implementations use two separate networks: one trained with the contrastive loss, and the other then fine-tuned for the classification task with the CE loss.

But in my case I need to proceed with only one network, with the contrastive loss included in the backward of one layer, so that the network learns the features of that particular layer and updates its weights accordingly. That’s why there are these terms in backward:

`grad_weight += cont_loss_weight` and `grad_bias += cont_loss_bias`

So, I have to stick to this approach :hugs: but I haven’t seen any such implementation or thread about this. :man_shrugging:

In doing so, what @ptrblck mentioned will happen: the layers before the custom layer will also have the gradients of the contrastive loss accumulated in addition to those of the CE loss. But I want to restrict this gradient accumulation (of contrastive + CE loss) to the custom layer.

Hopefully this justifies why I am doing it this way. :smile:

In that case, you can use the nightly version of PyTorch and use the new inputs argument of the .backward() function:

net.zero_grad()
ce_loss.backward()
additional_loss.backward(inputs=net.your_contrastive_layer.parameters())
opt.step()

So, for doing this, you mean I should create a separate contrastive loss function instead of computing it in the layer’s backward, and then use the nightly-version code that you mentioned. Right?

And can you please explain what this code will do: `additional_loss.backward(inputs=net.your_contrastive_layer.parameters())`?

I tried to install the nightly version from here.
But I encountered this error:

EnvironmentNotWritableError: The current user does not have write permissions to the target environment.
  environment location: C:\ProgramData\Anaconda3

Looks like I don’t have permission from the admin of the PC I’m using.

So, for doing this, you mean I should create a separate contrastive loss function instead of computing it in the layer’s backward.

Yes, I think that is going to be simpler than modifying the backward and writing the gradients yourself.

And can you please explain what this code will do.

It will run the backward as usual but will only update the .grad fields of the inputs that were given. So in your case, since you only want to update the parameters of that one layer, you pass just those parameters.
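
For instance, here is a minimal sketch (with a made-up two-layer net, not your model) of what happens to the .grad fields when inputs= is passed:

import torch

# Toy example: backward(inputs=...) only populates .grad for the listed tensors.
lin1 = torch.nn.Linear(4, 4)
lin2 = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
loss = lin2(lin1(x)).pow(2).mean()

loss.backward(inputs=list(lin2.parameters()))
print(lin1.weight.grad)        # None: lin1 was not listed in `inputs`
print(lin2.weight.grad.shape)  # torch.Size([2, 4])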

Looks like I don’t have permission from the admin of the PC I’m using.

You might want to create a new conda environment so that you can install things there.

Thanks for the explanation. :hugs:

I will have to look into this.

But if I stick to the code above, what do you think is the correct way to calculate cont_loss_weight: by Idea-1 or Idea-2?

And what changes should I make to the backward code to run it successfully? Any suggestions?

I also tried this without the nightly version, but the GPU went out of memory. :persevere:

RuntimeError: CUDA out of memory.

But if I stick to the code above, what do you think is the correct way to calculate cont_loss_weight: by Idea-1 or Idea-2?

I honestly don’t know. You will need to derive the mathematical formula for what the gradient should be with pen and paper first, then implement the final formula you get there.

I also tried this without the nightly version, but the GPU went out of memory. :persevere:

Does it run out at the first iteration? Or after a while?
Can you try reducing the batch size to reduce memory pressure?

In the research paper that I showed you, there is a gradient equation (i.e., the derivative of Li w.r.t. Zi for every i-th feature, if you compare with my code), but I don’t know why it is taken w.r.t. every feature instead of w.r.t. the parameters (weight and bias). Contradictorily, the loss formula doesn’t use the parameters directly, yet the features that go into the loss formula are themselves computed from the input and the parameters!
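
(My guess at the connection, to be checked: the paper only gives the gradient w.r.t. the features and leaves the step back to the parameters to the chain rule, since each feature z_i is itself a function of the weights:

\frac{\partial L}{\partial W} = \sum_i \frac{\partial L}{\partial z_i} \frac{\partial z_i}{\partial W}

which is exactly the part autograd would normally fill in.)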

If I use a smaller batch size, I can see results for the first few batches, but I suspect it would still go out of memory after some more iterations.
However, this loss function yields good results with bigger batch sizes, so I tried a bigger batch size, but it went out of memory after some time without showing the first result.

I’m seeing some unusual output for the code below; can you have a look, please?

class Custom_Convolution(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):  # input (from previous layer) shape = ([batch_size=100, 96, 8, 8])
        output = torch.nn.functional.conv2d(input, weight, bias, stride, padding)
        ctx.save_for_backward(input, weight, bias, output)
        return output    # output shape = ([batch_size=100, 128, 4, 4])

    @staticmethod
    def backward(ctx, grad_output):    # grad_output shape = ([batch_size, 128, 4, 4])

        input, weight, bias, output = ctx.saved_tensors    # input shape = ([batch_size, 96, 8, 8])
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output)  # shape = ([batch_size, 96, 8, 8])

        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)  # shape = ([128, 96, 5, 5])

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0, 2, 3))        # shape = ([128])

        with torch.enable_grad():
            feat = output.clone()   # output from forward, shape = ([batch_size, 128, 4, 4])

            feat = feat.view(feat.shape[0], feat.shape[1], -1)  # features shape = ([batch_size, 128, 16])

            cont_loss = torch.tensor([0.]).to(dev)
            for i in range(0, feat.shape[0]):
                for f in range(len(feat[i])):
                    Zi_unnormalized = feat[i][f]
                    Zi = torch.nn.functional.normalize(Zi_unnormalized, dim=0)
                    # Zj and Zk are tensors made from feat[i][*] and feat[other than i][*]. Zj and Zk vary for each Zi (or f).

                    Zi_Zk = torch.Tensor([0]).to(dev)
                    for k in Zk:
                        k = torch.nn.functional.normalize(k, dim=0)
                        zi_zk = ...
                        Zi_Zk = Zi_Zk.add(zi_zk)

                    # Similarly computing Zi_Zj
                    # Li = some algebra of Zi_Zj and Zi_Zk
                    # number of 'Li' values = feat.shape[0] * feat.shape[1]
                    cont_loss = cont_loss.add(Li)   # 1 value
            print("\n Loss: ", cont_loss, cont_loss.requires_grad)

### This line printing the loss keeps repeating with the same value of Loss!!!!

        cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True)
        print("Shape:", cont_loss_weight.shape)
        grad_weight += cont_loss_weight

        cont_loss_bias = torch.autograd.grad(outputs=cont_loss, inputs=bias, retain_graph=True)
        grad_bias += cont_loss_bias

        if bias is not None:
            return grad_input, grad_weight, grad_bias, None, None
        else:
            return grad_input, grad_weight, None, None

Output:

Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
.
.
.
RuntimeError: CUDA out of memory.

I don’t know why this line keeps repeating endlessly until CUDA finally goes out of memory. It should print only once per batch; moreover, I haven’t made any indentation mistake!

The line
`cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True)`
seems to get executed, but it neither shows any error nor returns anything, because the following line, `print("Shape:", cont_loss_weight.shape)`, never gets printed.

Hi @albanD,

When I tried to include autograd.grad in the backward as above, autograd.grad wasn’t returning anything even though it was getting executed. I don’t know why; can you please have a look?

Then I tried a different approach:

class Custom_Convolution(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):  # input (from previous layer) shape = ([batch_size=100, 96, 8, 8])
        with torch.enable_grad():
            output = torch.nn.functional.conv2d(input, weight, bias, stride, padding)
            h = output.shape[2]
            w = output.shape[3]
            # output from forward, shape = ([batch_size, 128, 4, 4])

            output = output.view(output.shape[0], output.shape[1], -1)  # output shape = ([batch_size, 128, 16])

            cont_loss = torch.tensor([0.]).to(dev).requires_grad_(True)
            for i in range(0, output.shape[0]):
                for f in range(len(output[i])):
                    Zi_unnormalized = output[i][f]
                    Zi = torch.nn.functional.normalize(Zi_unnormalized, dim=0)
                    # Zj and Zk are tensors made from output[i][*] and output[other than i][*]. Zj and Zk vary for each Zi (or f).

                    Zi_Zk = torch.Tensor([0]).to(dev).requires_grad_(True)
                    for k in Zk:
                        k = torch.nn.functional.normalize(k, dim=0)
                        zi_zk = ...
                        Zi_Zk = Zi_Zk.add(zi_zk)

                    # Similarly computing Zi_Zj
                    # Li = some algebra of Zi_Zj and Zi_Zk
                    # number of 'Li' values = output.shape[0] * output.shape[1]
                    cont_loss = cont_loss.add(Li)   # 1 value
        print("\n Loss: ", cont_loss, cont_loss.requires_grad)

        # weight1 = weight.clone().requires_grad_(True)
        # bias1 = bias.clone().requires_grad_(True)

        # weight.shape = ([128, 96, 5, 5])
        cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True)

        # bias.shape = ([128])
        cont_loss_bias = torch.autograd.grad(outputs=cont_loss, inputs=bias, retain_graph=True)

        output = output.view(output.shape[0], output.shape[1], h, w)
        ctx.save_for_backward(input, weight, bias, output, cont_loss, cont_loss_weight, cont_loss_bias)

        return output    # output shape = ([batch_size=100, 128, 4, 4])

    @staticmethod
    def backward(ctx, grad_output):    # grad_output shape = ([batch_size, 128, 4, 4])

        input, weight, bias, output, cont_loss, cont_loss_weight, cont_loss_bias = ctx.saved_tensors    # input shape = ([batch_size, 96, 8, 8])
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output)  # shape = ([batch_size, 96, 8, 8])

        if ctx.needs_input_grad[1]:
            grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)  # shape = ([128, 96, 5, 5])
            grad_weight += cont_loss_weight

        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum((0, 2, 3))        # shape = ([128])
            grad_bias += cont_loss_bias

        if bias is not None:
            return grad_input, grad_weight, grad_bias, None, None
        else:
            return grad_input, grad_weight, None, None

Then I observed that cont_loss_weight is a tuple object containing two tensors, each of shape ([96, 5, 5]). It should have returned a tensor of shape ([128, 96, 5, 5]) instead of a tuple, and similarly for cont_loss_bias, a tensor of shape ([128]).
I don’t know why!

Moreover, when I do `cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True)`, I am guessing grad_weight in backward will get affected. I have to keep retain_graph=True as well.

So to avoid that, when I used a copy of the parameters, i.e. `cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight1, retain_graph=True)`, I got this error:

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

The line `cont_loss_weight = torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True)` seems to get executed, but it neither shows any error nor returns anything, because the following line, `print("Shape:", cont_loss_weight.shape)`, never gets printed.

That line would print at least “Shape:” regardless of the result of the previous line, so these lines just never get run.
Also, autograd.grad always returns a tuple, even if you pass a single Tensor as inputs.
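
A small sketch (with made-up shapes) of what that means in practice, and how to unpack the result:

import torch

# autograd.grad returns a tuple with one entry per input, so the gradient
# tensor has to be indexed out (or unpacked) before using .shape.
w = torch.randn(128, 96, 5, 5, requires_grad=True)
loss = (w * 2).sum()

grads = torch.autograd.grad(outputs=loss, inputs=w, retain_graph=True)
print(type(grads))       # <class 'tuple'>
print(grads[0].shape)    # torch.Size([128, 96, 5, 5])

# equivalently, unpack the single-element tuple directly:
grad_w, = torch.autograd.grad(outputs=loss, inputs=w)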

As mentioned before, I don’t think the custom Function approach is the simplest one here, especially if you’re not already familiar with the mathematical definitions and the specific constructs.

Sorry, I am not getting it. Why doesn’t it reach the line `print("Shape:", cont_loss_weight.shape)` and print the shape? I don’t know why autograd is running the same loop again and again (I suspect something similar to this thread), because the same loss keeps getting printed, as you can see in the output I was getting :point_down:

The output I expected when I print `("Shape:", cont_loss_weight[0].shape)` is :point_down:, because it should return a tuple of size 1 whose element is basically a tensor of shape ([128, 96, 5, 5]):

Loss:  tensor([37.218], device='cuda:0', grad_fn=<AddBackward0>) True
Shape: ([128,96,5,5])
Loss:  tensor([different loss value for next batch], device='cuda:0', grad_fn=<AddBackward0>) True
Shape: ([128,96,5,5])
Loss:  tensor([different loss value for next to next batch], device='cuda:0', grad_fn=<AddBackward0>) True
.
.

If your autograd.grad call ends up calling this same backward function again, then yes, you will end up in infinite recursion. :slight_smile:

But I don’t think that is the case here. Because, you see:

Static forward:
The batch comes in as input, and the output of the convolution is returned.

Static backward:
We have the saved tensors of weight, bias, and output. cont_loss is calculated from the forward’s output, and within the static backward only the backpropagation of cont_loss w.r.t. weight and bias happens. And for one batch, the static backward of this layer will be called only once (hence, within backward, this backpropagation also happens only once per batch). :man_shrugging: :smile:
Right?

Isn’t the cont_loss that you compute during the backward computed from the forward’s output?
If so, when you try to get gradients w.r.t. the weights, it will backprop through this custom Function again, trying to get gradients for the weights given gradients of the output. And so you will recurse infinitely, right?
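
To illustrate with a toy Function of my own (not your conv layer): the output saved with save_for_backward still has this Function as its grad_fn, so an autograd.grad call on anything built from it re-enters backward():

import torch

# Toy sketch of the recursion: differentiating a loss built from the saved
# forward output goes back through this very Function's backward again.
class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        out = x * w
        ctx.save_for_backward(x, w, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        x, w, out = ctx.saved_tensors
        print("entered backward")                  # printed over and over
        with torch.enable_grad():
            extra_loss = out.clone().sum()         # graph: extra_loss -> clone -> Scale
        # This backprops through Scale again, so backward() is re-entered -> recursion.
        extra_w, = torch.autograd.grad(extra_loss, w, retain_graph=True)
        return grad_output * w, grad_output * x + extra_w

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
# Scale.apply(x, w).sum().backward()   # would recurse until it runs out of memory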

It is indeed :grin:
Yes, I got your point.

I had that doubt; that’s why I tried to move this into the forward, as in the code above.
I am getting the correct shapes of cont_loss_weight and cont_loss_bias there, but I doubt that it’s the correct way!! What are your thoughts? :thinking: