Setting custom gradient for backward

Hi,

Normally, one uses a forward pass to calculate a loss and then performs a backward pass, and the gradients are calculated automatically.

My situation is that I don’t have the loss, but I have already calculated the gradient. How do I set a custom gradient for a (fully connected) network and run the optimization step using this custom gradient? Thanks

You could assign the precomputed gradients to the .grad attribute of all parameters and call optimizer.step() afterwards. Here is a small example:

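import torch
import torch.nn as nn

lin = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.SGD(lin.parameters(), lr=1e-3)

print(lin.weight)
# assign a precomputed gradient directly to the .grad attribute
lin.weight.grad = torch.ones_like(lin.weight) * 10
print(lin.weight.grad)
optimizer.step()  # updates lin.weight using the assigned gradient
print(lin.weight)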
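
The snippet above sets the gradient of a single parameter. A minimal sketch of the same idea for every parameter of a small two-layer network, assuming the precomputed gradients are stored in a dict keyed by parameter name (the dict precomputed and its constant values are hypothetical placeholders for your own gradients):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A small two-layer fully connected network
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Hypothetical precomputed gradients, one per parameter with the same shape;
# in practice these would come from your own gradient computation
precomputed = {name: torch.full_like(p, 0.5) for name, p in model.named_parameters()}

# Copy of the weights before the update, for comparison
before = {name: p.detach().clone() for name, p in model.named_parameters()}

optimizer.zero_grad()
for name, p in model.named_parameters():
    # .grad must match the parameter's shape, dtype and device
    p.grad = precomputed[name].clone()
optimizer.step()  # plain SGD: p <- p - lr * p.grad

for name, p in model.named_parameters():
    print(name, torch.allclose(p, before[name] - 1e-3 * precomputed[name]))
```

With plain SGD (no momentum or weight decay) the update is simply p - lr * grad, which is what the final loop checks.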

Thank you! What if I have a network, e.g. a two-layer network? Can I assign the gradient to the last layer and then call backward() to perform the backpropagation for the first layer?

Yes, you could manually manipulate the gradients of any layer using the previous code snippet, or you could use hooks (via register_hook), which might be cleaner, especially if you are using multiple layers.
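
For the manual route, one option is to run the usual backward() and then edit the .grad tensors of a specific layer before calling optimizer.step(). A minimal sketch, where the scaling factor 0.1 and the choice of the last layer are arbitrary examples:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(3, 5), nn.ReLU(), nn.Linear(5, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(8, 3)
loss = model(x).pow(2).mean()

optimizer.zero_grad()
loss.backward()  # populates .grad of every parameter

# Keep a copy of the last layer's original weight gradient for comparison
orig_last = model[2].weight.grad.clone()

# Manually manipulate the gradients of a single layer before the update:
# here the last layer's gradients are scaled in place by 0.1
for p in model[2].parameters():
    p.grad.mul_(0.1)

optimizer.step()  # the update uses the modified gradients
```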

Thank you for your answer.
I realized that my problem is slightly different.

Rather than setting the gradient of a layer’s weight, what I would like to do is set the gradient of the loss w.r.t. the output of my network, i.e. dLoss/dOutput, which is not part of any network layer. It is more like setting the grad of a tensor variable. To be more specific, my pseudocode would be:

output = myNetwork(input)  

my_grad = getMyGradient()  #my custom dLoss/dOutput

output.grad = my_grad
output.backward() #usually backward starts from loss, but I would like to start from the output.
optimizer.step()

What would be the correct implementation for this kind of task?
Thank you in advance.

You can pass the gradient directly to the backward operation as an argument: output.backward(my_grad).
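
A small sketch of this, assuming my_grad is a stand-in (random values here) for the externally computed dLoss/dOutput and has the same shape as output. Since output is not a scalar, calling output.backward() without an argument would raise a RuntimeError; passing my_grad computes the vector-Jacobian product instead, which gives the same parameter gradients as backpropagating the scalar (output * my_grad).sum():

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(10, 4)
output = model(x)  # shape [10, 3], not a scalar

# Stand-in for the externally computed dLoss/dOutput (same shape as output)
my_grad = torch.randn_like(output)

optimizer.zero_grad()
output.backward(my_grad)  # backpropagates my_grad through the network
grads_direct = [p.grad.clone() for p in model.parameters()]

# Sanity check: the gradient of (output * my_grad).sum() w.r.t. output is my_grad,
# so backpropagating that scalar yields the same parameter gradients
optimizer.zero_grad()
model(x).mul(my_grad).sum().backward()
grads_surrogate = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_direct, grads_surrogate)))

optimizer.step()  # update the parameters with the resulting gradients
```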

It is working! Thank you!

Hi, ptrblck

Could you give an example of using hooks to manipulate the gradient of each layer (say, in a three-layer fully connected network)? I am new to PyTorch and have never used hooks before.

Sure, here is a small example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 2)
        self.fc3 = nn.Linear(2, 1)
        
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x
    
model = MyModel()
x = torch.randn(1, 1)

# each hook receives the parameter's gradient and returns the modified gradient
model.fc1.weight.register_hook(lambda grad: grad+1000.)
model.fc2.bias.register_hook(lambda grad: grad-1000.)
model.fc3.weight.register_hook(lambda grad: grad*1000.)

out = model(x)
out.mean().backward()
print([(name, param.grad) for name, param in model.named_parameters()])
# [('fc1.weight', tensor([[ 999.7733],
#         [1000.0000]])), ('fc1.bias', tensor([-0.2406,  0.0000])), ('fc2.weight', tensor([[0.0000, 0.0000],
#         [0.3678, 0.0000]])), ('fc2.bias', tensor([-1000.0000,  -999.4205])), ('fc3.weight', tensor([[ 0.0000, 67.0654]])), ('fc3.bias', tensor([1.]))]
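
The hooks above are registered on parameters. As a sketch of a related option, register_hook can also be called on an intermediate activation inside forward to change the gradient that flows back into the earlier layers. Here the hook (an arbitrary example) zeroes the gradient at fc2's output, so fc1 and fc2 receive all-zero gradients while fc3 is unaffected:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HookedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 2)
        self.fc2 = nn.Linear(2, 2)
        self.fc3 = nn.Linear(2, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        if x.requires_grad:
            # Hook on the activation (not a parameter): replace the gradient
            # flowing back into fc2's output with zeros
            x.register_hook(lambda grad: torch.zeros_like(grad))
        x = F.relu(x)
        return self.fc3(x)

torch.manual_seed(0)
model = HookedModel()
out = model(torch.randn(4, 1))
out.mean().backward()
for name, param in model.named_parameters():
    print(name, param.grad)
```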

Hi, ptrblck.
I want to train a model whose forward pass consists of three parts, A, B, and C, applied sequentially.
The main problem I encountered is that B is not differentiable,
so I replace it with a simulation model B_prime in the backward pass for the gradient computation
(but still use the original B in the forward pass).
Any suggestions on how to build such a custom backpropagation structure?
Thanks!

If you are using non-differentiable operations and want to implement a custom backward, you could define a custom autograd.Function as described here.
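
As a minimal sketch of the idea, assuming a stand-in non-differentiable op: the forward below uses torch.sign (whose autograd gradient is zero) and the backward substitutes the derivative of tanh as a smooth surrogate. The class name SignWithTanhGrad and the choice of tanh are illustrative only; in your case the surrogate would be B_prime:

```python
import torch

class SignWithTanhGrad(torch.autograd.Function):
    """Forward: hard torch.sign. Backward: derivative of tanh as a surrogate."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx tanh(x) = 1 - tanh(x)^2, used in place of sign's zero gradient
        return grad_output * (1 - torch.tanh(x) ** 2)

x = torch.linspace(-2, 2, 5, requires_grad=True)
y = SignWithTanhGrad.apply(x)
y.sum().backward()
print(y)       # hard values from sign()
print(x.grad)  # surrogate gradient 1 - tanh(x)^2
```

Both forward and backward are static methods that take ctx as the first argument, and the Function is called via .apply().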

Hi, thanks for your reply.
I’m using an autograd.Function with an nn.Module now, but I get an error when calling loss.backward() for mySys.
I’ve seen your reply below and found that my custom backward function’s output sim_out is detached from the computation graph (.grad_fn=None).
Here is a snippet of my code:

import torch
import torch.nn as nn

class B_nd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # some non-differentiable functions
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        input.requires_grad_(True)
        sim_model = UNet()
        sim_model.load_state_dict(torch.load('/Desktop/CodeFolder/2023/pth/800.pth', map_location='cuda:0'))
        sim_model.to(torch.device("cuda"))
        sim_model.zero_grad()
        sim_out = sim_model(input)
        print(sim_out.grad_fn)
        sim_out.backward(gradient=grad_output)
        gradient = input.grad
        return gradient

class mySys(nn.Module):
    def __init__(self):
        super(mySys, self).__init__()
        self.A = nn.Conv2d(3, 3, kernel_size=1)
        self.C = nn.Conv2d(3, 3, kernel_size=1)
    
    def forward(self, x):
        a_out = self.A(x)
        b_out = B_nd.apply(a_out)
        c_out = self.C(b_out)
        return c_out

Is there any method to attach the UNet used in B’s backward to the computation graph?

I don’t fully understand the explanation, as nothing is detached in your code.
By default, autograd runs backward with gradient calculation disabled, so you might just need to re-enable it in your backward via torch.set_grad_enabled(True).
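
A sketch of how the backward could look with gradient calculation re-enabled, assuming a tiny hypothetical surrogate network in place of the UNet and torch.round in place of the real non-differentiable B. It uses the torch.enable_grad() context manager (same effect as torch.set_grad_enabled(True) within the block), detaches the saved input so the surrogate's graph does not reach back into A, and uses torch.autograd.grad to get the vector-Jacobian product instead of reading input.grad:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the trained UNet surrogate (B_prime); in real code it
# would be created and loaded from its checkpoint once, outside backward()
surrogate = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Tanh(), nn.Conv2d(8, 3, 3, padding=1))
surrogate.requires_grad_(False)  # only its input gradient is needed here

class B_nd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Placeholder for the real non-differentiable B (rounding as an example)
        return torch.round(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # backward() runs with grad mode disabled, so re-enable it locally
        with torch.enable_grad():
            # Detach so the surrogate graph starts here and does not reach into A
            inp = input.detach().requires_grad_(True)
            sim_out = surrogate(inp)
            # Vector-Jacobian product of the surrogate with grad_output
            grad_input, = torch.autograd.grad(sim_out, inp, grad_outputs=grad_output)
        return grad_input

class mySys(nn.Module):
    def __init__(self):
        super().__init__()
        self.A = nn.Conv2d(3, 3, kernel_size=1)
        self.C = nn.Conv2d(3, 3, kernel_size=1)

    def forward(self, x):
        return self.C(B_nd.apply(self.A(x)))

model = mySys()
x = torch.randn(2, 3, 8, 8)
loss = model(x).pow(2).mean()
loss.backward()
print(model.A.weight.grad is not None, model.C.weight.grad is not None)
```

Loading the surrogate once outside backward() also avoids reading the checkpoint from disk on every backward pass.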