Inconsistent behavior in torch.repeat()'s backprop

Hi,

Given two tensors A and X, where A is constant (requires_grad=False) with shape [100, 100, 1, 1] and X is a trainable parameter (requires_grad=True) with shape [1, 100, 1, 1], I would expect these 3 expressions to produce the same result:

  1. (A * X)
  2. (A * X.expand(100, 100, 1, 1))
  3. (A * X.repeat(100, 1, 1, 1))
Since the tensor sizes do not match, X should be broadcast along the first dimension (conceptually copying its data) before the multiplication is performed.
As expected, all three expressions do in fact produce the same result, but there's a tiny issue during back-propagation.

Since the three cases are mathematically identical, I would expect X's gradients to be the same in all of them, but it turns out that expressions 1 and 2 consistently produce exactly the same result, while expression 3 always shows minor differences.

I understand that these operations have to add up multiple values along the first dimension during back-propagation, which is probably translated to CUDA atomicAdd operations that are known to be nondeterministic as explained here, but isn't it weird that expressions 1 and 2 always get the same results, and expression 3 is the only non-deterministic calculation?
I don't understand how torch.repeat() can differ from torch.expand() during back-propagation, given that torch.expand() in this case is deterministic and gives the exact same results as (A*X).
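
To make the "mathematically identical" part concrete: all three expressions compute (A * X) with X broadcast along the first dimension, so the gradient of the sum with respect to X should simply be A reduced over that dimension. A quick sanity check (illustrative only, run on CPU):

import torch

size = 100
a = torch.rand(size, size, 1, 1)
x = torch.ones(1, size, 1, 1, requires_grad=True)

(a * x).sum().backward()

# d/dX of sum(A * X) is A summed over the broadcast dimension
print(torch.allclose(x.grad, a.sum(dim=0, keepdim=True)))  # True (up to floating-point rounding)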

The code below takes expression 1 (which computes (A*X)) as the reference, first compares its gradient values to the gradients produced by torch.expand(), and then uses torch.repeat() to show that it gives a different result.

Code:

import torch

device = 'cuda'
size = 100
dtype = torch.float

a = torch.rand(size, size, 1, 1, device=device, dtype=dtype)

x = torch.ones(1,   size, 1, 1, dtype=dtype, device=device, requires_grad=True)
y = torch.ones(1,   size, 1, 1, dtype=dtype, device=device, requires_grad=True)

(a * x).sum().backward() # Calculating the expected result, expression #1


(a * y.expand(size,size,1,1)).sum().backward() # Expression #2 -  using torch.expand()
torch.equal(x.grad,y.grad) # This always returns True, as expected.

y.grad = None   # Clearing y's gradient before the next test

(a * y.repeat(size,1,1,1)).sum().backward() # Expression #3 -  using torch.repeat()
torch.equal(x.grad,y.grad) # This always returns False even though it performs the exact same operation!!!

Additional comments:

  1. When running on CPU (device = 'cpu') everything works as expected: torch.expand() and torch.repeat() both give the same result as (A*X).
  2. The behavior stays the same when using dtype=torch.double.
  3. Everything works as expected when changing X and Y to torch.double but generating A in single precision and only afterwards converting it to double (using .double()).
  4. When reducing the size to below 5, everything works as expected (it worked more than 20 times in a row); with size=5 there is about a 50% chance of it working (see the sketch after this list).
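
For reference, the kind of loop behind comment 4 looks roughly like this (the grads_match helper below is just an illustrative name, not from my actual script):

import torch

def grads_match(size, runs=20, device='cuda', dtype=torch.float):
    # Returns True if the repeat-based gradient matched the reference in every run
    a = torch.rand(size, size, 1, 1, device=device, dtype=dtype)
    for _ in range(runs):
        x = torch.ones(1, size, 1, 1, device=device, dtype=dtype, requires_grad=True)
        y = torch.ones(1, size, 1, 1, device=device, dtype=dtype, requires_grad=True)
        (a * x).sum().backward()                        # reference: plain broadcasting
        (a * y.repeat(size, 1, 1, 1)).sum().backward()  # expression 3
        if not torch.equal(x.grad, y.grad):
            return False
    return True

print(grads_match(4))    # sizes below 5 always matched for me
print(grads_match(100))  # size=100 reliably mismatches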

Thanks.

Hi,

There is one major difference between these 3 cases: .repeat() is the only one that actually allocates memory for the "larger version of X".
One possible explanation for what you see is that, since there is more memory to read when doing the multiplication and reductions, a different kernel launch configuration needs to be used. The operations can then happen in a different order, which can lead to small differences, I'm afraid.
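
You can already see that structural difference on the forward side, for example (illustrative only):

import torch

x = torch.ones(1, 100, 1, 1)

# expand() returns a view: same storage, the broadcast dimension gets stride 0
xe = x.expand(100, 100, 1, 1)
print(xe.stride(), xe.data_ptr() == x.data_ptr())   # (0, 1, 1, 1) True

# repeat() materializes the copies: fresh storage, regular strides
xr = x.repeat(100, 1, 1, 1)
print(xr.stride(), xr.data_ptr() == x.data_ptr())   # (100, 1, 1, 1) False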

Also, checking the implementation, it seems like they don't call the exact same sum method when doing the reduction, which might explain the 1-bit error as well.
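
As a generic illustration (nothing to do with the actual kernels), summing the same values with a different grouping can already change the last bits of the result:

import torch

torch.manual_seed(0)
v = torch.rand(10_000)

s1 = v.sum()                             # one flat reduction
s2 = v.view(100, 100).sum(dim=1).sum()   # same terms, different grouping

print(s1.item(), s2.item())
print((s1 - s2).abs().item())  # may be non-zero: same math, different rounding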

Hi @albanD,

Thank you for your response.

Can you please point me to where these functions are implemented? I'm having trouble finding them by myself.

Hi,

Sure!
The backward of expand is defined here, where sum_to() is a nice wrapper that calls sum with a list of dimensions to reduce.
On the other hand, the backward of repeat in the same file is repeat_backward(), which is defined here (it was most likely implemented before we added support for reducing multiple dims in sum).
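
Roughly speaking, the two backward paths boil down to something like this at the Python level (just a sketch, not the actual C++ code):

import torch

size = 100
grad_out = torch.rand(size, size, 1, 1)   # incoming gradient of the product

# expand backward: reduce the broadcast dimension in one sum (what sum_to does here)
grad_expand = grad_out.sum(dim=0, keepdim=True)

# repeat backward: fold the repeats into an extra leading dimension, then sum it out
grad_repeat = grad_out.reshape(size, 1, size, 1, 1).sum(dim=0)

print(torch.allclose(grad_expand, grad_repeat))  # mathematically the same reduction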
