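The memory is taken by the output `out1` and the intermediate activations needed to compute the gradient. The first increase comes from computing `out1`. The second increase comes from computing `net(data1)` while the previous `out1` is still alive. The reason is that in

```python
out1 = net(data1)
```

the right-hand side `net(data1)` is evaluated before the assignment, so the old and the new outputs (with their activations) exist at the same time. Memory usage as reported by the system doesn't generally decrease, because freed memory is usually kept by the allocator for reuse rather than handed back. If it did decrease, you would see it drop back to 2872Mi after the assignment, once the old `out1` is released.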
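The evaluation-order effect can be seen without a GPU in a minimal plain-Python sketch. `Output` and `net` are hypothetical stand-ins (not PyTorch APIs), and the sketch assumes CPython, where an object is freed as soon as its last reference disappears; tensor storage is released in the same way.

```python
class Output:
    """Hypothetical stand-in for a large output tensor; counts live instances."""
    live = 0
    peak = 0

    def __init__(self):
        Output.live += 1
        Output.peak = max(Output.peak, Output.live)

    def __del__(self):
        # CPython calls this as soon as the last reference is dropped.
        Output.live -= 1


def net(data):
    # Hypothetical model: allocates a fresh output on every call.
    return Output()


data1 = None
out1 = net(data1)        # one output alive
for i in range(3):
    out1 = net(data1)    # right-hand side runs while the old out1 is still alive

print(Output.peak)  # 2
```

The peak is two live outputs, even though only one is reachable after each assignment.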
You can rewrite your program to avoid keeping two versions of `out1` alive at once:
```python
def eval(network, input):
    out1 = network(input)
    # maybe use out1 here


for i in range(10):
    eval(net, data1)
```
As long as you don't return `out1` from `eval`, the local `out1` is freed when `eval` returns, before the next call, so you'll only use 2872Mi.
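The same hypothetical counter can illustrate the function-scope fix. This is a sketch under the same CPython assumption; `evaluate` mirrors the `eval` function above but is renamed here only to avoid shadowing Python's built-in `eval`.

```python
class Output:
    """Hypothetical stand-in for a large output tensor; counts live instances."""
    live = 0
    peak = 0

    def __init__(self):
        Output.live += 1
        Output.peak = max(Output.peak, Output.live)

    def __del__(self):
        Output.live -= 1


def net(data):
    # Hypothetical model: allocates a fresh output on every call.
    return Output()


def evaluate(network, data):
    out1 = network(data)
    # ... use out1 here; the local reference is dropped when this returns


for i in range(10):
    evaluate(net, None)

print(Output.peak, Output.live)  # 1 0
```

Because each output dies with its function call, at most one exists at any time, unlike the reassignment loop.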