Even more interesting cases! Now I'm lost on this simple Clone gradient test:
import torch

class Clone(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = x.clone()
        return y

    @staticmethod
    def backward(ctx, y_grad):
        print(f'inside backward, y_grad.stride(): {y_grad.stride()}')
        print(f'inside backward, y_grad.data_ptr(): {y_grad.data_ptr()}')
        return y_grad  # returned as-is, to examine the side effect
# case 1
x = torch.rand((2, 2), requires_grad=True)
y = Clone.apply(x)
y.mean().backward()
print(f'done backward, x.grad.stride(): {x.grad.stride()}')
print(f'done backward, x.grad.data_ptr(): {x.grad.data_ptr()}')
# out:
# inside backward, y_grad.stride(): (2, 1)
# inside backward, y_grad.data_ptr(): 140442810773632
# done backward, x.grad.stride(): (2, 1)
# done backward, x.grad.data_ptr(): 140442810773632
# This case is easy to understand: y_grad is simply passed through, and x.grad
# ends up being the very same tensor (same data_ptr, no copy).
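# Side note (a guess, not traced through the internals): mean()'s backward presumably
# broadcasts the scalar gradient and divides by N, which would materialize a fresh
# contiguous tensor that nothing else references -- consistent with the (2, 1) stride
# printed above. The values of x.grad should all be 1/N here:
print(x.grad)  # every entry should be 1/4 = 0.25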
# case 2
x = torch.rand((2, 2), requires_grad=True)
y = Clone.apply(x)
y.sum().backward()
print(f'done backward, x.grad.stride(): {x.grad.stride()}')
print(f'done backward, x.grad.data_ptr(): {x.grad.data_ptr()}')
# out:
# inside backward, y_grad.stride(): (0, 0)
# inside backward, y_grad.data_ptr(): 140442883547904
# done backward, x.grad.stride(): (2, 1)
# done backward, x.grad.data_ptr(): 140442883507008
# Interesting case, but still understandable.
# The backward of sum() hands in the incoming scalar gradient expanded to x's shape,
# a broadcast view with stride (0, 0), which forces x.grad to be initialized as a fresh
# copy that matches x's memory layout
# see:
# https://pytorch.org/docs/stable/autograd.html#default-gradient-layouts
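# Side note: what such a "scalar-expanded" gradient looks like. sum()'s backward
# presumably just broadcasts the incoming scalar gradient, allocating no new memory:
g = torch.ones(())                            # the scalar gradient fed in by .sum().backward()
g_expanded = g.expand(2, 2)                   # broadcast view over the single element
print(g_expanded.stride())                    # (0, 0), matching what Clone.backward saw above
print(g_expanded.data_ptr() == g.data_ptr())  # True: still the scalar's storage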
# case 3:
x = torch.rand((2, 2), requires_grad=True)
y = Clone.apply(x)
y_grad = torch.rand_like(y)
print(f'before backward, y_grad.stride(): {y_grad.stride()}')
print(f'before backward, y_grad.data_ptr(): {y_grad.data_ptr()}')
y.backward(y_grad)
print(f'done backward, x.grad.stride(): {x.grad.stride()}')
print(f'done backward, x.grad.data_ptr(): {x.grad.data_ptr()}')
# out:
# before backward, y_grad.stride(): (2, 1)
# before backward, y_grad.data_ptr(): 140442809548672
# inside backward, y_grad.stride(): (2, 1)
# inside backward, y_grad.data_ptr(): 140442809548672
# done backward, x.grad.stride(): (2, 1)
# done backward, x.grad.data_ptr(): 140442809838848
# This one is the really interesting one!
# What makes torch think that x.grad needs a fresh copy here, as opposed to case 1?
# It looks like torch figured out something about the underlying data ownership and
# decided to copy, but how, why, and in what scenarios does it do so?
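# A further probe (just a sketch; I haven't pinned down the exact rule): pass a gradient
# that no Python name keeps alive, so only the autograd engine holds it during backward.
# If the copy-vs-steal decision depends on how many references the incoming gradient has,
# then the data_ptr printed inside backward and x.grad.data_ptr() below may now match.
x = torch.rand((2, 2), requires_grad=True)
y = Clone.apply(x)
y.backward(torch.rand_like(y))
print(f'done backward, x.grad.stride(): {x.grad.stride()}')
print(f'done backward, x.grad.data_ptr(): {x.grad.data_ptr()}')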
Case 3 is quite interesting. To my knowledge this involves object lifetime and ownership. If this were C++ I could explain it: the backward input y_grad is std::move'd into x_grad in both case 1 and case 3, but in case 1 no other object holds the underlying data, so torch can decide that x.grad uniquely owns it and no extra copy is needed, while in case 3 another named object (the global variable y_grad) holds the same data, making the reference count on the data greater than 1, so torch decides to make an extra copy so that x.grad can uniquely own its data. But that's all C++. In these Python APIs, how exactly is this done? My naive understanding of Python is that a global garbage collector manages everything, with no guarantee about which object is a temporary that can be killed immediately off the stack. Today's example tells me I'm probably wrong. Does Python also have some C++-like local-variable management on the stack that enables this kind of ownership transfer? Otherwise, how is the above achieved?