How does clone interact with backpropagation?

I was wondering how the .clone() function interacts with backpropagation.

For example, I have two pieces of code and I don’t understand the difference between them. Can someone explain it to me (maybe so I understand why clone is even needed for backprop to work “properly”, whatever properly means)?

Example 1:

import torch
from torch.autograd import Variable
x = Variable(torch.rand(2,1),requires_grad=False) #data set
w = Variable(torch.rand(2,2),requires_grad=True) #first layer
v = Variable(torch.rand(2,2),requires_grad=True) #last layer
w[:,1] = v[:,0]
y = torch.matmul(v, torch.matmul(w,x) )

vs

import torch
from torch.autograd import Variable
x = Variable(torch.rand(2,1),requires_grad=False) #data set
w = Variable(torch.rand(2,2),requires_grad=True) #first layer
v = Variable(torch.rand(2,2),requires_grad=True) #last layer
w[:,1] = v[:,0]
y = torch.matmul(v, torch.matmul(w.clone(),x) )

Why do we need clone in the second one? What does “backprop working properly” mean?


Both of your scripts fail for me.

Short reason:

RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

If you want to perform in-place operations (like w[:, 1] = v[:, 0]), you can clone the variable before the in-place operation. The clone is no longer a leaf variable, so the in-place modification is allowed on it.

w1 = w.clone()
w1[:,1] = v[:,0]
y = torch.matmul(v, torch.matmul(w1,x) )
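
Since w1 is not a leaf variable, the in-place assignment is allowed, and gradients from y still flow back to both w (through the clone) and v (through the slice assignment and the matmul).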

Using in-place operations on Variables is tricky in many cases. A lot of frameworks don’t support them at all and just perform copies instead. PyTorch does support in-place operations, but because other operations can need the original content of a Variable to compute the backward pass, modifying it in place can give you wrong gradients.
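
To make that concrete, here is a minimal sketch (using the current tensor API rather than the old Variable wrapper). The backward of sigmoid needs its own output, so writing into that output in place makes autograd raise an error instead of silently computing wrong gradients; the exact error text may differ between versions:

import torch

a = torch.rand(3, requires_grad=True)
b = a.sigmoid()   # sigmoid's backward uses the values stored in b
b[0] = 0.0        # in-place write into a tensor that backward still needs

try:
    b.sum().backward()
except RuntimeError as e:
    # e.g. "one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    print(e)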


If I make a copy of a Variable and extract a slice of the copy, does the gradient only backprop through the sliced elements?
What if this Variable is one of the model's parameters?


Funny, I was wondering exactly what the title of the question is asking: how does backprop interact with clone?

I built this simple script:

## print clone backward
import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
c = a.sigmoid()
c_cloned = c.clone()

print(f'c_cloned: {c_cloned}')

and got this:

c_cloned: tensor([0.7311, 0.8808, 0.9526], grad_fn=<CloneBackward>)

but notice that autograd recorded the clone operation (i.e. grad_fn=<CloneBackward>), which I thought was very weird.

So my question is: how does backprop react to a clone operation? What does it even mean to backprop through a clone operation?


All incoming gradients to the cloned tensor will be propagated to the original tensor as seen here:

x = torch.randn(2, 2, requires_grad=True)
y = x.clone()
y.retain_grad()

z = y**2
z.mean().backward()

print(y.grad)
print(x.grad)
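
Both prints show the same tensor here: CloneBackward just passes the incoming gradient through to the original tensor, so x.grad == y.grad == 2 * x / 4.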

Hi Ptrblck,

Just want to double check: will it be the same for “SliceBackward”?

Yes, the backward function of the slicing operation will propagate the gradient to the sliced values in the original tensor, as seen here:

x = torch.randn(5, 5, requires_grad=True)
y = x[3:4, 3:5]
print(y)
# tensor([[-0.0479,  0.8982]], grad_fn=<SliceBackward0>)

y.mean().backward()
print(x.grad)
# tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.0000, 0.5000, 0.5000],
#         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

Thank you for the help! One follow-up question: is there a way to slice a tensor batch-wise?
For example:

output = model(input), where output has a shape of 10 x 1 x 32 x 32 (batch size x channel x height x width)

The slicing operation would be different for each sample:
sample 1: output[0, :, :16, :16]
sample 2: output[1, :, :15, :15]
sample 3: output[2, :, :26, :26] … something like that.

Does PyTorch have a way to do the above operation efficiently, without using a for loop?

I guess you can use advanced indexing, but only if the sizes of the different slices match, e.g.:
sample 1: output[0, :, 1:16, 1:16]
sample 2: output[1, :, :15, :15]
sample 3: output[2, :, 11:26, 11:26]
(see the sketch after the link below)

Here is the NumPy documentation on advanced indexing (PyTorch works similarly):
https://numpy.org/devdocs/user/basics.indexing.html#advanced-indexing
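
For completeness, here is a hypothetical sketch of that idea: equally sized 15 x 15 crops taken at a different offset for each sample, gathered with advanced indexing instead of a Python loop. The tensor names and offsets are made up for illustration.

import torch

N, C, H, W, crop = 3, 1, 32, 32, 15
output = torch.randn(N, C, H, W, requires_grad=True)

row_start = torch.tensor([1, 0, 11])   # per-sample top offsets
col_start = torch.tensor([1, 0, 11])   # per-sample left offsets

rows = row_start[:, None] + torch.arange(crop)   # (N, crop) row indices
cols = col_start[:, None] + torch.arange(crop)   # (N, crop) column indices
batch = torch.arange(N)[:, None, None]           # (N, 1, 1) batch indices

# Advanced indexing on dims 0, 2 and 3; the sliced channel dim ends up last,
# so permute it back to N x C x crop x crop.
crops = output[batch, :, rows[:, :, None], cols[:, None, :]]
crops = crops.permute(0, 3, 1, 2)
print(crops.shape)  # torch.Size([3, 1, 15, 15])

# Gradients flow back only into the cropped region of each sample.
crops.mean().backward()
print(output.grad.abs().sum(dim=(1, 2, 3)))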