Check if a parameter accumulates gradients multiple times during backpropagation

Hi all. Thanks for checking this thread. I feel very lucky to have found a community while I'm getting my feet wet learning PyTorch.

I'm training my neural network by running a beam search to generate a loss.
In the computational graph, some parameters are used multiple times to calculate the loss for a single training instance, so I think it makes sense that during backpropagation those parameters will accumulate gradients multiple times. But how can I verify this?

Any input is much appreciated.

If you just want to get a notification once the gradient of a specific parameter has been computed, you could register a hook on this parameter and e.g. add a print statement:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2a = nn.Linear(10, 10)
        self.fc2b = nn.Linear(10, 10)
        self.act = nn.ReLU()
        
    def forward(self, x):
        x = self.act(self.fc1(x))
        x1 = self.fc2a(x)
        x2 = self.fc2b(x)
        return x1, x2


model = MyModel()
model.fc1.weight.register_hook(lambda grad: print('grad accumulated in fc1'))

criterion = nn.MSELoss()
x = torch.randn(1, 10)
target = torch.randn(1, 10)

out1, out2 = model(x)
loss1 = criterion(out1, target)
loss2 = criterion(out2, target)

loss1.backward(retain_graph=True)
> grad accumulated in fc1
loss2.backward()
> grad accumulated in fc1
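
If you would rather count (or store) these events than just print them, the hook can append to a list instead. A small sketch along the same lines, reusing criterion, x and target from above:

grad_events = []

def count_hook(grad):
    # called every time a gradient w.r.t. fc1.weight is computed
    grad_events.append(grad.clone())

model = MyModel()
model.fc1.weight.register_hook(count_hook)

out1, out2 = model(x)
loss1 = criterion(out1, target)
loss2 = criterion(out2, target)

loss1.backward(retain_graph=True)
loss2.backward()
print(len(grad_events))
> 2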

Alternatively, you could clone the gradient after the first backward pass and compare it to the gradient after the second backward call. However, for the sake of debugging I would personally prefer the first approach.
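
For completeness, the clone-and-compare approach would look roughly like this:

model = MyModel()
out1, out2 = model(x)
loss1 = criterion(out1, target)
loss2 = criterion(out2, target)

loss1.backward(retain_graph=True)
grad_after_first = model.fc1.weight.grad.clone()  # snapshot before the second backward

loss2.backward()
grad_after_second = model.fc1.weight.grad

# if the second backward accumulated a (nonzero) gradient into fc1.weight,
# the two snapshots will differ
print(torch.equal(grad_after_first, grad_after_second))
> False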

Let me know if this would work for you.

Thanks. But my loss is actually one overall loss, which I assume means it will remember all the operations done on that parameter. I'm not sure how I can clone the gradient after the first backward pass…

Are you dealing with a single loss, which contains multiple computation graphs involving a specific parameter?
If so, my suggestion won't work, as the grad attribute will only be updated once (see the snippet below).
Could you post a minimal example?
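
For example, with the dummy model above a single backward call on the summed loss triggers the hook only once, even though fc1 feeds both branches:

model = MyModel()
model.fc1.weight.register_hook(lambda grad: print('grad accumulated in fc1'))

out1, out2 = model(x)
loss = criterion(out1, target) + criterion(out2, target)
loss.backward()
> grad accumulated in fc1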

loss = (a1*A + a2*A + a3*A) - (b1*A + b2*A + b3*A + b4*A)

The original a1, a2, a3 are padded matrices representing some character strings; the same goes for b1, b2, b3, b4.

The values of a2 depend on a1, but this dependence should not be part of the computational graph. Similarly, the values of a3 depend on a2.

However, the value of a1 is the output of feeding some values to an LSTM, and the same holds for a2, a3, b1, …

So I expect that during backpropagation the loss will be backpropagated to the b4*A operation and then through the LSTM, then to the b3*A operation and through the LSTM again, etc…
Since they all use the same LSTM, I expect the LSTM parameters to accumulate gradients multiple times…
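
Something roughly like this toy sketch (not my real code; the names lstm and A and the shapes are made up just to illustrate the structure):

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=8, hidden_size=8, batch_first=True)
A = torch.randn(1, 5, 8)  # stand-in for the shared tensor A in the loss

def run(inp):
    # a1, a2, ..., b4 are each the output of feeding something to the same LSTM
    out, _ = lstm(inp)
    return out

a_inputs = [torch.randn(1, 5, 8) for _ in range(3)]  # padded "character string" matrices
b_inputs = [torch.randn(1, 5, 8) for _ in range(4)]

a_vals, prev = [], None
for inp in a_inputs:
    if prev is not None:
        inp = inp + prev.detach()  # a2 depends on a1's value, but not through the graph
    a_vals.append(run(inp))
    prev = a_vals[-1]

b_vals = [run(inp) for inp in b_inputs]

loss = sum((a * A).sum() for a in a_vals) - sum((b * A).sum() for b in b_vals)
loss.backward()
# the LSTM parameters now hold the summed gradient contributions of all seven terms
print(lstm.weight_ih_l0.grad.abs().sum())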

The backward call will follow all computation graphs and accumulate the gradients in the parameters accordingly.
If a1, a2, etc. were computed using the same model and thus the same parameters, the gradients will be accumulated in them.
Let's have a look at my dummy model:

# loss = loss1
# (re-seed and re-create the model so that all three runs start from identical weights and data)
torch.manual_seed(2809)
model = MyModel()
criterion = nn.MSELoss()
x = torch.randn(1, 10)
target = torch.randn(1, 10)

out1, out2 = model(x)
loss1 = criterion(out1, target)
loss2 = criterion(out2, target)
loss = loss1
loss.backward()
grad1 = model.fc1.weight.grad.clone()

# loss = loss2
torch.manual_seed(2809)
model = MyModel()
criterion = nn.MSELoss()
x = torch.randn(1, 10)
target = torch.randn(1, 10)

out1, out2 = model(x)
loss1 = criterion(out1, target)
loss2 = criterion(out2, target)
loss = loss2
loss.backward()
grad2 = model.fc1.weight.grad.clone()

# loss = loss1 + loss2
torch.manual_seed(2809)
model = MyModel()
criterion = nn.MSELoss()
x = torch.randn(1, 10)
target = torch.randn(1, 10)

out1, out2 = model(x)
loss1 = criterion(out1, target)
loss2 = criterion(out2, target)
loss = loss1 + loss2
loss.backward()
grad3 = model.fc1.weight.grad.clone()

torch.allclose(grad1 + grad2, grad3)
> True

As you can see, calling backward on each loss separately and summing the resulting gradients yields the same result as a single backward call on the summed loss.

If you check the grad_fn of loss, you’ll see multiple paths:

print(loss.grad_fn.next_functions)
> ((<MseLossBackward at 0x7f8cd85b0c18>, 0),
 (<MseLossBackward at 0x7f8cd85b0ba8>, 0))
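
If you want to count these paths programmatically, you could walk the graph recursively. This relies on internal attributes (next_functions and the variable attribute of AccumulateGrad nodes), so treat it as a debugging sketch rather than a stable API:

def count_paths_to_param(fn, param):
    # counts how many paths in the autograd graph end in the
    # AccumulateGrad node of `param`
    if fn is None:
        return 0
    if type(fn).__name__ == 'AccumulateGrad' and fn.variable is param:
        return 1
    return sum(count_paths_to_param(next_fn, param)
               for next_fn, _ in fn.next_functions)

print(count_paths_to_param(loss.grad_fn, model.fc1.weight))
> 2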

Thanks! I think the grad_fn paths are what I'm looking for. I'll check them.

Much appreciated!