Does PyTorch do eager pruning of its computational graph?

This is a very simple example:

import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(x * y * z)
s.backward()

print(x.grad)

This prints

tensor([2., 2., 0., 0., 0.])

since, of course, ds/dx = y * z is zero for the entries where z is zero.
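A small sketch (redefining the tensors above so the snippet stands alone) makes the gradients explicit. By the product rule, ds/dx = y * z, ds/dy = x * z and ds/dz = x * y. Because z has requires_grad=True here, z.grad needs every entry of x * y, including 5 * 2 = 10, which the forward pass then multiplies by 0. The name `xy` is only introduced for illustration; `x * y * z` is evaluated as `(x * y) * z` anyway.

```python
import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

# Same computation as in the question, with the intermediate product named.
# The forward pass evaluates x * y for every element, including 5 * 2 = 10.
xy = x * y
s = torch.sum(xy * z)
s.backward()

# Product rule: ds/dx = y * z, ds/dy = x * z, ds/dz = x * y
print(x.grad)  # values [2, 2, 0, 0, 0]
print(y.grad)  # values [1, 2, 0, 0, 0]
print(z.grad)  # values [2, 4, 6, 8, 10]: uses every entry of x * y
```

So even in this tiny example, skipping the products that end up multiplied by zero would give a wrong z.grad.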

My question is: is PyTorch smart enough to stop the computation when it reaches a zero, or does it in fact compute “2*5”, only to later compute “10 * 0 = 0”?

In this simple example it makes little difference, but in the (bigger) problem I am working on, it will.

Thank you for any input.

PS: question also posted here: https://stackoverflow.com/questions/54781966/does-pytorch-do-eager-pruning-of-its-computational-graph

Hi,

No, it does not.
The thing is that a zero only acts as a stopping point for multiplication. Other operations (an addition, for example) can turn a zero back into a nonzero value, so the computation has to continue after a 0 is encountered. So I don't think this can be done in general (even though it might be possible in some very specific cases).
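A minimal sketch of that point, using a hypothetical variant of the expression from the question (the `+ 1` is an assumption chosen only for illustration): the zeros in x * z are shifted by an addition before they reach the multiplication with y, so they no longer zero out anything. Stopping at those zeros would give a wrong forward result and a wrong y.grad.

```python
import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

# Hypothetical variant: the zeros in x * z are shifted by an addition
# before they reach the multiplication with y.
a = x * z + 1          # values [2, 3, 1, 1, 1]: no zeros left
s = torch.sum(a * y)   # 4 + 6 + 2 + 2 + 2 = 16
s.backward()

print(s.item())        # 16.0
print(x.grad)          # ds/dx = z * y     -> values [2, 2, 0, 0, 0]
print(y.grad)          # ds/dy = x * z + 1 -> values [2, 3, 1, 1, 1], nonzero where z == 0
```

Here x.grad is still zero where z is zero, but y.grad is not, so whether a zero can be pruned depends on every operation downstream of it.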