Different forward and backward weights

I have a use case where I need to use a different set of weights to compute the backward pass. One example is this work, https://www.nature.com/articles/ncomms13276, and its numerous follow-up works; another is https://www.siarez.com/projects/random-backpropogation

Is there any way to do this with the current API? If not, what would be the best angle of attack?


You could try to load the backward state_dict right before executing the backward operation, but it's quite a hacky approach:

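```python
import copy
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1, 1, bias=False),
    nn.Linear(1, 1, bias=False)
)

with torch.no_grad():  # illustrative values: both forward weights = 1
    model[0].weight.fill_(1.)
    model[1].weight.fill_(1.)

sd_forward = copy.deepcopy(model.state_dict())
sd_backward = copy.deepcopy(sd_forward)
sd_backward['1.weight'].fill_(10.)  # illustrative backward weight for the 2nd layer

# one train step
output = model(torch.ones(1, 1))
# load the backward weights; copying via .data avoids autograd's in-place check
for name, p in model.named_parameters():
    p.data.copy_(sd_backward[name])
output.backward()
print(model[0].weight.grad)  # tensor([[10.]])
print(model[1].weight.grad)  # tensor([[1.]])
```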

Also note that the last gradient is wrong, since the output was calculated using the old weights.
Would this approach work for you, or did I misunderstand your question?

@ptrblck Thanks for your reply. I actually figured out the right way to do this: I wrote my own autograd Function class, similar to the example here: https://pytorch.org/docs/stable/notes/extending.html
Inside its backward method, I use a different set of weights to compute grad_input, so the backward now looks like this:

```python
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, b_weights, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            # propagate the error with the backward weights instead of `weight`
            grad_input = grad_output.mm(b_weights)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, torch.zeros_like(b_weights), grad_bias
```

b_weights are the backward weights; they are passed to the forward function and saved in ctx.
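If you want to try this end to end, here is a minimal sketch of a complete Function built around that backward. It is modelled on the LinearFunction example from the Extending PyTorch notes. The class name FALinear, the shapes, and the random feedback matrix are illustrative assumptions.

In this sketch, b_weights is treated as a fixed tensor that does not require grad, so it returns None for it rather than zeros. The Extending PyTorch notes allow None for inputs that do not require grad.

```python
import torch
from torch.autograd import Function


class FALinear(Function):
    """Linear layer whose backward pass uses separate feedback weights.

    Hypothetical name; mirrors LinearFunction from the Extending PyTorch
    notes, except that grad_input is computed from b_weights.
    """

    @staticmethod
    def forward(ctx, input, weight, b_weights, bias=None):
        ctx.save_for_backward(input, weight, b_weights, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output = output + bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, b_weights, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            # Send the error backwards through the feedback weights, not `weight`.
            grad_input = grad_output.mm(b_weights)
        if ctx.needs_input_grad[1]:
            # The weight gradient keeps its usual form dL/dW.
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0)
        # b_weights is fixed here, so it receives no gradient.
        return grad_input, grad_weight, None, grad_bias


torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(2, 3, requires_grad=True)
b = torch.randn(2, requires_grad=True)
b_w = torch.randn(2, 3)  # fixed random feedback weights, same shape as w

out = FALinear.apply(x, w, b_w, b)
out.sum().backward()  # grad_output is a (4, 2) tensor of ones

# x.grad is built from b_w; w.grad is the ordinary linear-layer gradient
print(torch.allclose(x.grad, torch.ones(4, 2).mm(b_w)))  # True
print(torch.allclose(w.grad, torch.ones(4, 2).t().mm(x.detach())))  # True
```

With a fixed random b_weights, this corresponds to the feedback alignment setup from the Nature Communications paper linked in the question. Only the error passed to earlier layers changes.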


That looks like a good approach! Thanks for sharing it! :slight_smile:

@ptrblck, I have a question that might be silly 🤐.
When we call loss.backward(), do grad_input, grad_output, grad_weight and grad_bias correspond to the following?

grad_input = gradient of loss w.r.t. input,
grad_output = gradient of loss w.r.t. output,
grad_weight = gradient of loss w.r.t. weight,
grad_bias = gradient of loss w.r.t. bias.

And what do grad_input, grad_output, grad_weight and grad_bias mean as per this tutorial? Gradients of what and w.r.t. what?

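The weight and bias gradients are the gradients of the loss w.r.t. these parameters.
The input and output gradients are also gradients of the loss. grad_output (the gradient w.r.t. the layer's output) is passed into backward. grad_input (the gradient w.r.t. the layer's input) is calculated from it via the chain rule and forwarded to the next layer in the backward pass (the previous layer during the forward pass).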
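Here is a small sketch for a linear layer y = x Wᵀ that makes the naming concrete. The shapes and the squared-sum loss are arbitrary choices for illustration. The sketch computes grad_input and grad_weight with the same chain-rule formulas used in LinearFunction.backward, starting from grad_output = dL/dy. It then checks them against torch.autograd.grad applied to the loss.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 3, requires_grad=True)   # layer input
W = torch.randn(4, 3, requires_grad=True)   # layer weight
y = x.mm(W.t())                             # layer output
loss = (y ** 2).sum()                       # any scalar loss

# Gradients of the loss, computed by autograd
dL_dy, dL_dx, dL_dW = torch.autograd.grad(loss, (y, x, W))

# grad_output is dL/dy; for this loss dL/dy = 2 * y
grad_output = 2 * y.detach()
# Chain rule, as in LinearFunction.backward
grad_input = grad_output.mm(W.detach())       # dL/dx
grad_weight = grad_output.t().mm(x.detach())  # dL/dW

print(torch.allclose(dL_dy, grad_output, atol=1e-6))  # True
print(torch.allclose(dL_dx, grad_input, atol=1e-6))   # True
print(torch.allclose(dL_dW, grad_weight, atol=1e-6))  # True
```

All four names therefore refer to gradients of the same scalar loss, each taken with respect to a different tensor.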

I am confused because, when I checked the documentation for

```python
grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output)
grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output)
```

I learned that input is the input to the (convolution) layer. This code file (https://github.com/pytorch/pytorch/blob/master/torch/nn/grad.py) also says that the conv2d_weight function (line 170) computes the gradient of the convolution's output with respect to the convolution's weight.

So where is the gradient of the loss w.r.t. the parameters calculated, taking for example class LinearFunction(Function): from the Extending PyTorch tutorial?

I'm not sure if I understand the question properly, but you would apply the chain rule, so the gradient w.r.t. the conv output (grad_output) would be used.
The general workflow of the chain rule and backpropagation is explained e.g. in CS231n - Optimization.
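The sketch below checks the chain-rule point above. The shapes, the stride/padding values and the loss are arbitrary assumptions.

torch.nn.grad.conv2d_input and conv2d_weight take grad_output as an argument and apply the chain rule to it. When grad_output is dL/d(output), which is what autograd passes into backward, their results match the gradients of the loss w.r.t. the input and the weight. So in a custom Function, the loss gradient w.r.t. the parameters is produced where backward combines grad_output with the saved tensors.

```python
import torch
import torch.nn.functional as F
from torch.nn.grad import conv2d_input, conv2d_weight

torch.manual_seed(0)
inp = torch.randn(2, 3, 8, 8, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
out = F.conv2d(inp, weight, stride=2, padding=1)  # shape (2, 4, 4, 4)
loss = out.pow(2).mean()                           # arbitrary scalar loss

# Reference gradients of the loss from autograd
dL_dout, dL_dinp, dL_dw = torch.autograd.grad(loss, (out, inp, weight))

# The same gradients via torch.nn.grad, fed with grad_output = dL/dout
grad_input = conv2d_input(inp.shape, weight.detach(), dL_dout, stride=2, padding=1)
grad_weight = conv2d_weight(inp.detach(), weight.shape, dL_dout, stride=2, padding=1)

print(torch.allclose(grad_input, dL_dinp, atol=1e-6))  # True
print(torch.allclose(grad_weight, dL_dw, atol=1e-6))   # True
```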

Thank you for replying.

I am familiar with the chain rule, but I don't know where exactly the gradients of the loss w.r.t. the parameters are calculated in code if I use an autograd Function as described in the Extending PyTorch tutorial.

Could you please look at the code snippet below, in which I have commented my doubts? I think this is a better way to explain them :sweat_smile:

```python
class Custom_Convolution(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):  # input's shape = ([batch_size=100, 96, 8, 8])
        output = torch.nn.functional.conv2d(input, weight, bias, stride, padding)
        ctx.save_for_backward(input, weight, bias, output)
        return output  # output's shape = ([batch_size=100, 128, 4, 4])

    @staticmethod
    def backward(ctx, grad_output):  # grad_output's shape = ([batch_size=100, 128, 4, 4])
        input, weight, bias, output = ctx.saved_tensors  # input's shape = ([batch_size=100, 96, 8, 8])
        print("op:", output.shape, output.requires_grad, output.grad_fn)
        ## It shows that output requires gradient and
        ## grad_fn = <torch.autograd.function.Custom_ConvolutionBackward object at ...>

        ## I am cloning the output because I think otherwise the gradients
        ## of the already existing output tensor would be overwritten, which may
        ## affect the later calculation of grad_input, grad_weight and grad_bias.
        features = output.clone()
        print("op2:", features.requires_grad, features.grad_fn)
        ## It prints: False, None !!!!

        features = features.view(features.shape[0], features.shape[1], -1)
        # Total_features = features.shape[0] * features.shape[1]
        cont_loss = torch.tensor([0.]).requires_grad_(requires_grad=True).to(dev)  # shape: ([1])
        for ..... :
            # My code for the loss... includes some operations like torch.div, exp, sum...
            # Calculation of the loss for each feature 'i': Li
            # cont_loss += Li (number of Li values = features.shape[0] * features.shape[1])
```

I want to backpropagate from cont_loss to features (i.e. output), and then from features to the weight tensor.
But when I use torch.autograd.grad(outputs=cont_loss, inputs=weight, retain_graph=True),

I get RuntimeErrors such as:
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior
or: One of the tensor used in computational graph either does not require gradient or it has No gradient function.

Based on your comments, it seems you would like to compute something like second-order gradients, since you want to create grad_fns inside the backward. If that's the case, enable gradient calculation inside backward with a with torch.enable_grad(): block.
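Here is a minimal sketch of what that can look like. It rests on assumptions the thread does not spell out: the goal is to add the gradient of an auxiliary loss on the layer's output to the usual grad_weight. aux_loss is a hypothetical stand-in for cont_loss, here just the mean of the squared features, and the class name ConvWithAuxGrad is also hypothetical.

Two things explain the errors above:
- Grad mode is off inside backward, so output.clone() gets no grad_fn.
- The saved output is linked to weight only through the custom Function's own node, not through conv2d.

The sketch therefore recomputes the convolution under torch.enable_grad() from detached copies of the saved tensors that require grad. It also assumes a bias tensor is always passed.

```python
import torch
import torch.nn.functional as F


def aux_loss(features):
    # Hypothetical placeholder for the thread's contrastive loss (cont_loss)
    return features.pow(2).mean()


class ConvWithAuxGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias, stride, padding):
        ctx.save_for_backward(input, weight, bias)
        ctx.stride, ctx.padding = stride, padding
        return F.conv2d(input, weight, bias, stride, padding)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        with torch.enable_grad():
            # Fresh leaf tensors, so the recomputed graph is independent of the outer one
            inp = input.detach().requires_grad_()
            w = weight.detach().requires_grad_()
            b = bias.detach().requires_grad_()  # assumes a bias was passed
            out = F.conv2d(inp, w, b, ctx.stride, ctx.padding)
            # Ordinary gradients: chain rule with the incoming grad_output
            grad_input, grad_weight, grad_bias = torch.autograd.grad(
                out, (inp, w, b), grad_output, retain_graph=True)
            # Extra gradient of the auxiliary loss w.r.t. the weight only
            aux_grad_w, = torch.autograd.grad(aux_loss(out), w)
        # stride and padding are not tensors, so they get None
        return grad_input, grad_weight + aux_grad_w, grad_bias, None, None


torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, requires_grad=True)
bias = torch.randn(4, requires_grad=True)

out = ConvWithAuxGrad.apply(x, weight, bias, 2, 1)
out.sum().backward()

# Reference computed without the custom Function
w_ref = weight.detach().requires_grad_()
ref_out = F.conv2d(x.detach(), w_ref, bias.detach(), 2, 1)
main_grad, = torch.autograd.grad(ref_out.sum(), w_ref, retain_graph=True)
aux_grad, = torch.autograd.grad(aux_loss(ref_out), w_ref)
print(torch.allclose(weight.grad, main_grad + aux_grad, atol=1e-5))  # True
```

Recomputing the forward inside backward costs an extra convolution. In exchange, it gives a graph in which the auxiliary loss actually depends on the weight.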

Hi, I modified my code but this is happening.