Inconsistent behavior in torch.repeat()'s backprop


Given two tensors A and X, where A is constant (requires_grad=False) with shape [100, 100, 1, 1] and X is a trainable parameter (requires_grad=True) with shape [1, 100, 1, 1], I would expect these 3 expressions to produce the same result:

  1. (A * X)
  2. (A * X.expand(100, 100, 1, 1))
  3. (A * X.repeat(100, 1, 1, 1))

Since the tensor sizes do not match, X has to be extended along its first dimension (conceptually, by copying its data) before the multiplication is performed. As expected, all three expressions do produce the same result, but there’s a tiny issue during back-propagation.
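
A minimal CPU sketch of the forward comparison, assuming only that PyTorch is installed (CPU is used so it runs without a GPU; the helper name `forward_results` is hypothetical). It also computes the analytically expected gradient of X: the derivative of sum(A * X) with respect to X is A summed over the first dimension, so every expression should give X.grad equal to A.sum(0, keepdim=True) up to rounding.

```python
import torch

def forward_results(size: int = 100, seed: int = 0):
    """Hypothetical helper: evaluate the three expressions on CPU."""
    torch.manual_seed(seed)
    a = torch.rand(size, size, 1, 1)                   # constant, requires_grad=False
    x = torch.ones(1, size, 1, 1, requires_grad=True)  # trainable parameter
    out1 = a * x                                       # 1. implicit broadcasting
    out2 = a * x.expand(size, size, 1, 1)              # 2. explicit expand()
    out3 = a * x.repeat(size, 1, 1, 1)                 # 3. repeat()
    return a, x, out1, out2, out3

a, x, out1, out2, out3 = forward_results()

# Each output element is a single product a[i, j] * x[0, j], so the forward
# results are bitwise identical.
print(torch.equal(out1, out2), torch.equal(out1, out3))  # True True

# The gradient of sum(A * X) w.r.t. X is A summed over dim 0.
out1.sum().backward()
expected = a.sum(dim=0, keepdim=True)
print(torch.allclose(x.grad, expected))
```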

Since the 3 expressions are mathematically identical, I would expect X’s gradients to be the same in all 3 cases, but it turns out that expressions 1 and 2 consistently get exactly the same result, while expression 3 always differs slightly.

I understand that during back-propagation these operations have to add multiple values along the first dimension, which are probably translated to CUDA atomicAdd operations that are known to be nondeterministic, as explained here. But isn’t it weird that expressions 1 and 2 always get the same results, and expression 3 is the only non-deterministic calculation?
I don’t understand how torch.repeat() can behave differently from torch.expand() during back-propagation, given that torch.expand() is deterministic in this case and gives exactly the same results as (A*X).
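
The order of additions matters because floating-point addition is not associative. A minimal sketch in plain Python (Python floats are double precision, but float32 on the GPU behaves the same way, with larger rounding errors); the helper name `sum_left_to_right` is hypothetical:

```python
def sum_left_to_right(values):
    """Hypothetical helper: add values strictly in the given order."""
    total = 0.0
    for v in values:
        total += v
    return total

# Same three numbers, two different groupings.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, left == right)  # 0.6000000000000001 0.6 False

# Reordering the inputs changes the result in the same way, which is what
# happens when parallel additions (e.g. atomicAdd) complete in a different order.
vals = [0.1, 0.2, 0.3]
print(sum_left_to_right(vals), sum_left_to_right(reversed(vals)))  # 0.6000000000000001 0.6
```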

The code below assumes that expression 1 (which computes (A*X)) gives the correct result. It first compares its gradients to the gradients produced by torch.expand(), then uses torch.repeat() to show that it gives different results.


import torch

device = 'cuda'
size = 100
dtype = torch.float

a = torch.rand(size, size, 1, 1, device=device, dtype=dtype)

x = torch.ones(1, size, 1, 1, dtype=dtype, device=device, requires_grad=True)
y = torch.ones(1, size, 1, 1, dtype=dtype, device=device, requires_grad=True)

(a * x).sum().backward()  # Expected result, expression #1 (broadcasting)

(a * y.expand(size, size, 1, 1)).sum().backward()  # Expression #2, using torch.expand()
print(torch.equal(x.grad, y.grad))  # This always prints True, as expected.

y.grad = None  # Clear y's gradient before evaluating the next expression

(a * y.repeat(size, 1, 1, 1)).sum().backward()  # Expression #3, using torch.repeat()
print(torch.equal(x.grad, y.grad))  # This always prints False, even though the operation is mathematically the same!
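
Because the two backward paths may add the same values in a different order, a bitwise check with torch.equal() is stricter than the math requires. A minimal sketch of a tolerance-based comparison, assuming the default torch.allclose() tolerances are acceptable here (the helper name `compare_grads` is hypothetical, and the script falls back to CPU when CUDA is unavailable):

```python
import torch

def compare_grads(reference: torch.Tensor, other: torch.Tensor):
    """Hypothetical helper: bitwise equality, tolerance check, max abs difference."""
    max_diff = (reference - other).abs().max().item()
    return torch.equal(reference, other), torch.allclose(reference, other), max_diff

device = 'cuda' if torch.cuda.is_available() else 'cpu'
size = 100

a = torch.rand(size, size, 1, 1, device=device)
x = torch.ones(1, size, 1, 1, device=device, requires_grad=True)
y = torch.ones(1, size, 1, 1, device=device, requires_grad=True)

(a * x).sum().backward()                        # expression #1
(a * y.repeat(size, 1, 1, 1)).sum().backward()  # expression #3

bitwise, close, max_diff = compare_grads(x.grad, y.grad)
print(f"bitwise equal: {bitwise}, allclose: {close}, max abs diff: {max_diff:.3e}")
```

A difference of a few ULPs shows up as bitwise inequality but still passes the tolerance check.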

Additional comments:

  1. When running on the CPU (device = 'cpu'), everything works as expected: both torch.expand() and torch.repeat() give the same result as (A*X).
  2. The behavior stays the same when using dtype=torch.double.
  3. Everything works as expected when X and Y are torch.double but A is generated as torch.float and only then converted to double (using .double()).
  4. With a size below 5, everything works as expected (it worked more than 20 times in a row); with size=5, it works about 50% of the time.



There is one major difference between these 3 cases: .repeat() is the only one that actually allocates memory for the “larger version of X”.
One possible explanation for what you see is that, because there is more memory to read during the multiplication and reduction, a different kernel launch configuration is used. The operations can then happen in a different order, which, I’m afraid, can lead to small differences.
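
The memory difference is easy to check. A minimal CPU sketch: expand() returns a view that shares X's storage and uses a stride of 0 along the expanded dimension, while repeat() allocates a new contiguous tensor holding 100 copies of the data.

```python
import torch

size = 100
x = torch.ones(1, size, 1, 1)

expanded = x.expand(size, size, 1, 1)
repeated = x.repeat(size, 1, 1, 1)

# expand(): a view over x's memory, with stride 0 on the expanded dimension.
print(expanded.stride(), expanded.data_ptr() == x.data_ptr())  # (0, 1, 1, 1) True

# repeat(): a freshly allocated, contiguous tensor with size * size elements.
print(repeated.stride(), repeated.data_ptr() == x.data_ptr())  # (100, 1, 1, 1) False
print(repeated.is_contiguous(), repeated.numel())              # True 10000
```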

Also, checking the implementation, it seems they don’t call exactly the same sum method when doing the reduction, which might also explain the 1-bit error.

Hi @albanD,

Thank you for your response.

Can you please point me to where these functions are implemented? I’m having trouble finding them myself.


The backward of expand is defined here, where sum_to() is a convenient wrapper that calls sum() with a list of dimensions to reduce.
On the other hand, the backward of repeat, in this same file, is repeat_backward(), which is defined here (it was most likely implemented before we added support for summing over multiple dims).
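
To make the difference concrete, here is a minimal Python sketch of two reduction strategies, written by analogy with the description above; it is not the actual C++ code, and both function names are hypothetical. `expand_backward_sketch` reduces dim 0 with a single sum() call (roughly what sum_to() amounts to for this shape), while `repeat_backward_sketch` splits the incoming gradient into one chunk per repetition and adds the chunks one after another, the kind of chunk-by-chunk accumulation an implementation without multi-dim sum might use. Both are mathematically equal but add in a different order.

```python
import torch

def expand_backward_sketch(grad: torch.Tensor) -> torch.Tensor:
    """Hypothetical: reduce the expanded dim 0 with one sum() call."""
    return grad.sum(dim=0, keepdim=True)

def repeat_backward_sketch(grad: torch.Tensor, repeats: int) -> torch.Tensor:
    """Hypothetical: split dim 0 into `repeats` chunks and add them sequentially."""
    chunks = grad.chunk(repeats, dim=0)
    total = chunks[0].clone()
    for chunk in chunks[1:]:
        total = total + chunk
    return total

size = 100
a = torch.rand(size, size, 1, 1)
# For sum(A * Z), the gradient flowing into the enlarged Z is A itself.
grad_expand = expand_backward_sketch(a)
grad_repeat = repeat_backward_sketch(a, repeats=size)

# Same math, different summation order: close, but not guaranteed bitwise equal.
print(torch.allclose(grad_expand, grad_repeat), torch.equal(grad_expand, grad_repeat))
```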
