Backpropagation through time: when to call .backward()?

I am training a convLSTM network, and I decided to check whether BPTT gives different results when backward() is called on a cumulative loss or on each individual loss. I looked at the gradient of the weight w1, and it is not exactly the same in the two cases. The code is below; can anyone comment on the difference?

import torch
import torch.nn as nn
import numpy

w1 = torch.nn.Parameter(torch.tensor(([[1,2],[4,1]]),dtype=torch.float))
x = torch.nn.Parameter(torch.tensor(([[1,5],[4,5]]),dtype=torch.float),requires_grad=False)
y = torch.nn.Parameter(torch.tensor(([[2,7],[3,4]]),dtype=torch.float),requires_grad=False)
loss_fn = torch.nn.MSELoss(reduction='mean')

t= 20

#initial grad
print(w1.grad)

#case 1: call backward() once, outside the loop, on the cumulative loss
loss = 0
x1=x.clone()
for i in range(t):
  y_pred = w1.mm(x1)
  step_loss = loss_fn(y_pred, y)  # compute each step's loss once
  loss += step_loss
  print(i, step_loss.item())
  x1= y_pred.detach()
loss.backward()

case1 = (w1.grad).data

#reset grad of w1
w1.grad= None

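#case 2: call backward() on each step's loss inside the loop
x1=x.clone()
for i in range(t):
  y_pred = w1.mm(x1)
  step_loss = loss_fn(y_pred, y)
  # retain_graph is not needed: x1 is detached, so each step has its own graph,
  # and the gradients of successive backward() calls accumulate in w1.grad
  step_loss.backward()
  print(i, step_loss.item())
  x1= y_pred.detach()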

case2 = (w1.grad).data

print(torch.all(torch.eq(case1, case2))) #this gives False
print(numpy.allclose(case1, case2)) #this gives True: they are close but not identical

Hi,

Please avoid using .data. Not sure why you use it here as w1.grad does not require gradient. If you want a copy, you need .clone().
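A minimal sketch of why the copy matters, assuming a recent PyTorch (the names `w`, `alias` and `snapshot` are just for illustration): `.data` returns a tensor that shares storage with `w.grad`, so an in-place reset of the gradient also changes it, while `.clone()` copies the values. In the question's code, `w1.grad = None` replaces the gradient tensor instead of modifying it, so `case1` happens to survive; an in-place reset would not.

```python
import torch

# Same values as w1 in the question; the loss is arbitrary, just to get a gradient.
w = torch.nn.Parameter(torch.tensor([[1.0, 2.0], [4.0, 1.0]]))
(w * w).sum().backward()  # gradient of sum(w*w) is 2*w

alias = w.grad.data        # shares storage with w.grad
snapshot = w.grad.clone()  # independent copy of the values

w.grad.zero_()  # in-place reset of the gradient

print(alias)     # all zeros: it sees the in-place change
print(snapshot)  # still 2*w
```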

These gradients are expected to be close but not equal, I'm afraid.
Floating point operations are not exact, and their results can depend on the order in which they are performed. Calling backward() at every step and summing the losses before a single backward() perform the same additions in different orders, so the rounding errors differ slightly.
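To see why the order matters, here is a small sketch (not from the thread) showing that floating point addition is not associative, both for Python floats and for float32 tensors like the ones in the question.

```python
import torch

# Python floats (float64): grouping changes the rounded result.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# float32 tensors: a small term added to a large one can be lost to rounding.
big = torch.tensor(1e8, dtype=torch.float32)
one = torch.tensor(1.0, dtype=torch.float32)
print(((big + one) - big).item())  # 0.0: 1e8 + 1 rounds back to 1e8 in float32
print(((big - big) + one).item())  # 1.0
```

Accumulating 20 per-step gradients of very different magnitudes, as in the question, is subject to the same effect.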

@albanD thanks so much. I also suspected floating point precision, since the gradient gets very large in this particular example. I used .data because I wanted to compare the numeric values. I will also consider your other advice. So the conclusion is that they are the same, right?

.data is not really a thing anymore since we removed Variables. It still does things, but mostly buggy things, so you should avoid it :smiley:
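A small sketch of the kind of bug meant here, along the lines of the example in PyTorch's 0.4.0 migration notes (the helper `grad_after_inplace_edit` is hypothetical): an in-place change to a tensor obtained through `.data` is invisible to autograd and silently corrupts the gradient, whereas the same change on a `.detach()`ed tensor is detected.

```python
import torch

def grad_after_inplace_edit(use_data):
    """Hypothetical helper: zero the output of sigmoid in place, then backprop."""
    a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    out = a.sigmoid()  # sigmoid's backward reuses its output
    view = out.data if use_data else out.detach()
    view.zero_()       # both views share storage with `out`
    out.sum().backward()
    return a.grad

# With .data, autograd is not told about the edit: the gradient is silently wrong.
print(grad_after_inplace_edit(use_data=True))  # tensor([0., 0., 0.])

# With .detach(), the in-place change is tracked and backward() raises instead.
try:
    grad_after_inplace_edit(use_data=False)
except RuntimeError as err:
    print("RuntimeError:", err)
```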

the conclusion is they are the same, right?

They are as close as we can get them. So here, yes, they are “the same” up to floating point precision.
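If you want to check this programmatically, here is a sketch assuming the same setup as the question: compare the two gradients with a tolerance (`torch.allclose`) instead of exact equality, and optionally rerun in float64, where the gap between the two summation orders should shrink by many orders of magnitude. `grads_two_ways` is a hypothetical helper, not part of PyTorch.

```python
import torch

def grads_two_ways(t=20, dtype=torch.float32):
    """Hypothetical helper: the question's two cases, returning both gradients of w1."""
    w1 = torch.nn.Parameter(torch.tensor([[1., 2.], [4., 1.]], dtype=dtype))
    x = torch.tensor([[1., 5.], [4., 5.]], dtype=dtype)
    y = torch.tensor([[2., 7.], [3., 4.]], dtype=dtype)
    loss_fn = torch.nn.MSELoss(reduction="mean")

    # Case 1: sum the per-step losses, then a single backward().
    loss, x1 = 0, x
    for _ in range(t):
        y_pred = w1.mm(x1)
        loss = loss + loss_fn(y_pred, y)
        x1 = y_pred.detach()
    loss.backward()
    g_once = w1.grad.clone()

    # Case 2: backward() on each step's loss; gradients accumulate in w1.grad.
    w1.grad = None
    x1 = x
    for _ in range(t):
        y_pred = w1.mm(x1)
        loss_fn(y_pred, y).backward()
        x1 = y_pred.detach()
    g_per_step = w1.grad.clone()
    return g_once, g_per_step

for dtype in (torch.float32, torch.float64):
    g1, g2 = grads_two_ways(dtype=dtype)
    max_rel_diff = ((g1 - g2).abs() / g1.abs()).max().item()
    print(dtype, "equal:", torch.equal(g1, g2),
          "allclose:", torch.allclose(g1, g2), "max rel diff:", max_rel_diff)
```

A relative tolerance is the natural choice here because the gradient entries are very large, so any fixed absolute tolerance would be meaningless.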