I usually define a new `Function` when `backward()` is much simpler than `forward()`, because some of the variables added to the computation graph during `Module.forward()` can be meaningless.

For example,

```
# Inside of torch.nn.Module
def forward(self, x):
    x1 = x + 3
    x2 = x1 + 3
    x3 = x2 + 3
    x4 = x3 + 3
    x5 = x4 + 3
    return x5
```

Here, the variables `x, x1, x2, x3, x4, x5` may all be added to the computation graph so that the gradient can be computed later. However, only `x` is needed to compute the gradient, because the only operation involved is `+`.
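As a quick check of that claim (a minimal sketch, assuming a recent PyTorch version; the shape is just for illustration), the gradient of the plain addition chain is simply passed through unchanged:

```
import torch

x = torch.randn(3, requires_grad=True)
y = ((((x + 3) + 3) + 3) + 3) + 3   # same chain as in the Module above
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.]) -- addition only passes the gradient through
```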

Considering this case, I define a new `Function` in order to save some memory:

```
# Inside of torch.autograd.Function
@staticmethod
def forward(ctx, x):
    ctx.save_for_backward(x)
    x1 = x + 3
    x2 = x1 + 3
    x3 = x2 + 3
    x4 = x3 + 3
    x5 = x4 + 3
    return x5

@staticmethod
def backward(ctx, z):
    x, = ctx.saved_tensors
    # d(x5)/dx is 1, so the incoming gradient z is just passed through
    return z * torch.ones_like(x)
```
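For reference, this is a minimal sketch of how I would wrap and call it (the class name `AddFifteen` and the `gradcheck` call are my own illustration, not part of the original code):

```
import torch
from torch.autograd import Function, gradcheck

class AddFifteen(Function):      # hypothetical name, just for illustration
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x + 15            # same result as the five chained +3 additions above

    @staticmethod
    def backward(ctx, z):
        x, = ctx.saved_tensors
        return z * torch.ones_like(x)

x = torch.randn(3, dtype=torch.double, requires_grad=True)  # double precision for gradcheck
y = AddFifteen.apply(x)                   # custom Functions are called via .apply()
print(gradcheck(AddFifteen.apply, (x,)))  # True if backward matches the numerical gradient
```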

Is this a right case for defining a new `Function`? Or does PyTorch's autograd system automatically remove meaningless variables for `backward()`?