Change the gradient during backward pass

Below is the output, followed by the code that produces it. I want to change a gradient during the backward pass: for example, the gradient flowing into `out1` (printed by the hook) has -22 as its first element, and I want to manually replace that value with 50.

output and loss is  tensor([8., 8.], grad_fn=<MulBackward0>) tensor(386.5000, grad_fn=<MseLossBackward>)
tensor([-22., -17.])
output and loss is  tensor([289.6000, 225.6000], grad_fn=<MulBackward0>) tensor(53816.2656, grad_fn=<MseLossBackward>)
tensor([9397.5205, 5656.9204])
import torch
import torch.nn as nn


class Net(torch.nn.Module):
    """Toy model computing ``(2 * x) * w`` with a learnable weight vector ``w``.

    A hook on the intermediate tensor ``out1 = 2 * x`` prints the gradient
    that flows into it during ``backward()``.
    """

    def __init__(self):
        super().__init__()
        # Learnable per-element scale, initialised to ones (2 elements).
        self.w = torch.nn.Parameter(torch.ones(2))

    def forward(self, x):
        out1 = 2 * x
        # register_hook raises on a tensor that does not require grad (e.g.
        # under torch.no_grad() or when x is a plain tensor), so only attach
        # the hook when a gradient will actually flow through out1.
        if out1.requires_grad:
            # The hook receives d(loss)/d(out1); print() returns None, which
            # leaves the gradient unchanged.
            out1.register_hook(lambda grad: print(grad))
        return torch.mul(out1, self.w)

   
learning_rate = 0.2
model = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# x is the non-leaf result of 4 * ones(...), so every iteration's backward
# also passes through that shared node; retain_graph=True below keeps the
# graph's saved state so later iterations can backpropagate through it again.
x = 4 * torch.ones(2, requires_grad=True)
target = torch.tensor([30.0, 25.0])

for i in range(2):
    optimizer.zero_grad()
    out = model(x)
    # nn.MSELoss takes (input, target). MSE is symmetric, so the value is
    # the same either way, but this order matches the documented signature.
    loss = criterion(out, target)
    print("output and loss is ", out, loss)
    loss.backward(retain_graph=True)
    optimizer.step()

Hi,

Your hook can return a new Tensor, which will then be used instead of the gradient it was given.
Note that hooks should never modify their input in place, so clone it first. For example:

def my_hook(grad):
    """Gradient hook: replace the first entry of the incoming gradient with 50.

    Hooks must not modify their input in place, so the gradient is cloned
    and the modified copy is returned. Autograd then uses the returned tensor
    in place of the original gradient.
    """
    grad = grad.clone()
    # Address the first element for any rank: grad[0] for 1-D, grad[0, 0]
    # for 2-D, and so on. A plain grad[0][0] raises IndexError on a 1-D
    # gradient.
    grad[(0,) * grad.dim()] = 50
    print("new grad: ", grad)
    return grad

Then register it as a hook on `out1` with `out1.register_hook(my_hook)`.

I am registering it from `forward`, as shown in the code below, but it fails with a NameError saying that `my_hook` is not defined. I also tried the changes shown in Code 2, but that gives the same error.

import torch.nn as nn
class Net(torch.nn.Module):
    """Question code: ``(2 * x) * w`` with an attempted gradient hook on ``2 * x``.

    NOTE(review): this snippet imports only ``torch.nn as nn`` yet refers to
    ``torch``; presumably ``import torch`` appears earlier in the poster's file.
    As written, ``forward`` fails with the NameError described in the thread.
    """

    def __init__(self):       
        super(Net, self).__init__()
        # Learnable per-element scale, initialised to ones (2 elements).
        self.w=torch.nn.Parameter(torch.ones(2))
       
    # Defined in the class body, so this is a method of Net rather than a
    # module-level function. It has no ``self`` parameter, so registering it
    # as self.my_hook would call it with two arguments (instance, grad) and
    # fail with TypeError.
    # NOTE(review): grad[0][0] raises IndexError on a 1-D gradient; grad[0]
    # would address its first element.
    def my_hook(grad):
      # Hooks must not modify their input in place, hence the clone.
      grad = grad.clone()
      grad[0][0] = 50
      print("new grad: ", grad)
      # The returned tensor replaces the incoming gradient.
      return grad

    def forward(self, x):
        out1= 2 * x
        # Bug under discussion: inside a method, the bare name ``my_hook`` is
        # looked up in local and module scope, not in the class body, so this
        # line raises NameError: name 'my_hook' is not defined.
        out1.register_hook(my_hook)
        out = torch.mul(out1,self.w)
        return out

   
learning_rate = 0.2
model = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# x is the non-leaf result of 4 * ones(...), so every iteration's backward
# also passes through that shared node; retain_graph=True below keeps the
# graph's saved state so later iterations can backpropagate through it again.
x = 4 * torch.ones(2, requires_grad=True)
target = torch.tensor([30.0, 25.0])

for i in range(2):
    optimizer.zero_grad()
    out = model(x)
    # nn.MSELoss takes (input, target). MSE is symmetric, so the value is
    # the same either way, but this order matches the documented signature.
    loss = criterion(out, target)
    print("output and loss is ", out, loss)
    loss.backward(retain_graph=True)
    optimizer.step()

Code 2:

import torch

import torch.nn as nn

class Net(torch.nn.Module):
    """Question code ("Code 2"): ``(2 * x) * w``, keeping ``2 * x`` on the module.

    The intermediate ``2 * x`` is stored as ``self.out1`` so that a hook can
    be attached to it from outside ``forward``.
    """

    def __init__(self):       

        super(Net, self).__init__()

        # Learnable per-element scale, initialised to ones (2 elements).
        self.w=torch.nn.Parameter(torch.ones(2))

        # Placeholder only; forward() rebinds self.out1 to 2 * x on every call.
        self.out1=torch.ones(2)

       

    # Instance method: it must be referenced as self.my_hook / model.my_hook,
    # not as a bare ``my_hook``.
    # NOTE(review): grad[0][0] raises IndexError on a 1-D gradient such as
    # that of 2 * x in this example; grad[0] would address its first element.
    def my_hook(self,grad):

      # Hooks must not modify their input in place, hence the clone.
      grad = grad.clone()

      grad[0][0] = 50

      print("new grad: ", grad)

      # The returned tensor replaces the incoming gradient.
      return grad

    def forward(self, x):

        self.out1= 2 * x

        # Registration is disabled here, so no hook is attached during the
        # forward pass. Calling self.out1.register_hook(self.my_hook) at this
        # point would attach it before backward() runs.
        # out1.register_hook(my_hook)

        out = torch.mul(self.out1,self.w)

        return out

   

learning_rate = 0.2

model=Net()

criterion=torch.nn.MSELoss()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# x is the non-leaf result of 4 * ones(...), shared by every iteration's graph.
x=4*torch.ones((2),requires_grad=True)

target=torch.FloatTensor([30,25])

for i in range (0,2):

    optimizer.zero_grad()

    out=model(x)

    # NOTE(review): nn.MSELoss expects (input, target); the arguments are
    # swapped here. MSE is symmetric, so the loss value is unaffected.
    loss= criterion(target,out)

    print("output and loss is ",out,loss)

    loss.backward(retain_graph=True)

    # NOTE(review): ``my_hook`` is defined as a method of Net, not as a
    # module-level function, so this bare name raises the NameError reported
    # in the question. Even as model.my_hook, registering the hook after
    # loss.backward() has already run cannot change that pass's gradients.
    model.out1.register_hook(my_hook)

    optimizer.step()

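Defining `my_hook` inside the class body makes it a method of the class, not a module-level function. Inside `forward` you must therefore refer to it as `self.my_hook`, and it needs a `self` parameter (`def my_hook(self, grad)`). Also register the hook before calling `backward()`: registering it afterwards (as in Code 2) has no effect on the backward pass that already ran. Since `out1` is 1-D here, index it with `grad[0]` rather than `grad[0][0]`.
Alternatively, you can define the hook inline inside `forward`: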

import torch
import torch.nn as nn
class Net(torch.nn.Module):
    """Toy model computing ``(2 * x) * w`` with a learnable weight ``w``.

    During ``backward()``, a hook on the intermediate ``2 * x`` tensor replaces
    the first entry of its incoming gradient with 50. Autograd then
    propagates the modified gradient on towards ``x``.
    """

    def __init__(self):
        super().__init__()
        # Learnable per-element scale, initialised to ones (2 elements).
        self.w = torch.nn.Parameter(torch.ones(2))

    def forward(self, x):
        def overwrite_first(incoming):
            # Work on a copy: hooks must never modify their input in place.
            patched = incoming.clone()
            patched[0] = 50
            print("new grad: ", patched)
            # The returned tensor is used instead of the original gradient.
            return patched

        doubled = 2 * x
        doubled.register_hook(overwrite_first)
        return torch.mul(doubled, self.w)

   
# Training setup carried over from the question; only `model` is exercised
# by the smoke test below.
learning_rate = 0.2
model = Net()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Smoke test: one forward/backward pass on a random 2-element input. The hook
# registered in Net.forward prints the modified gradient.
sample = torch.rand(2, requires_grad=True)
prediction = model(sample)
prediction.sum().backward()
1 Like