This is a very simple example:
import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)  # zeros mask the last three entries

s = torch.sum(x * y * z)
s.backward()
print(x.grad)
This will print
tensor([2., 2., 0., 0., 0.])
since, of course, ds/dx is zero for the entries where z is zero.
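Indeed, since s = sum(x * y * z), the analytic gradient is ds/dx = y * z elementwise, which can be checked in one line (continuing the snippet above):

print(torch.equal(x.grad, (y * z).detach()))  # True -- the gradient is exactly y * z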
My question is: is PyTorch smart enough to stop the computation when it reaches a zero? Or does it in fact compute “2 * 5”, only to later compute “10 * 0 = 0”?
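One way to probe this empirically is to wrap the multiplication in a custom torch.autograd.Function that counts its backward calls (a minimal sketch; CountedMul is an illustrative helper, not a PyTorch API):

import torch

class CountedMul(torch.autograd.Function):
    """Elementwise multiply that counts how often its backward runs."""
    backward_calls = 0

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        CountedMul.backward_calls += 1
        a, b = ctx.saved_tensors
        # Dense elementwise products either way -- no value-based short-circuit here.
        return grad_out * b, grad_out * a

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(CountedMul.apply(CountedMul.apply(x, y), z))
s.backward()
print(CountedMul.backward_calls)  # 2 -- both backwards ran despite the zeros
print(x.grad)                     # tensor([2., 2., 0., 0., 0.])

If the counter reaches 2 even though the upstream gradient for the masked entries is zero, autograd invoked every node's backward rather than pruning on values.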
In this simple example it doesn’t make much of a difference, but in the (bigger) problem I am looking at, it will.
Thank you for any input.
PS: question also posted here: https://stackoverflow.com/questions/54781966/does-pytorch-do-eager-pruning-of-its-computational-graph