Here’s a more precise and complete example. What my example does is bypass autograd’s automatic backward computation entirely and manually reverse-compute the backward graph, segment by segment.
For anyone coming here from a search: my solution is a hack and not good practice. It is given only as an illustration, to show @zazzyy how to short-circuit these things.
import torch
import torch.nn as nn
from torch.autograd import Variable
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        ])

    def forward(self, x):
        self.output = []
        self.input = []
        for layer in self.layers:
            # detach from previous history
            x = Variable(x.data, requires_grad=True)
            self.input.append(x)
            # compute output
            x = layer(x)
            # add to list of outputs
            self.output.append(x)
        return x

    def backward(self, g):
        for i, output in reversed(list(enumerate(self.output))):
            if i == (len(self.output) - 1):
                # for the last node, seed backward with the external gradient g
                output.backward(g)
            else:
                # for intermediate nodes, seed backward with the gradient that
                # accumulated on the next segment's detached input
                output.backward(self.input[i + 1].grad.data)
                print(i, self.input[i + 1].grad.data.sum())
model = Net()
inp = Variable(torch.randn(4, 10))
output = model(inp)
gradients = torch.randn(*output.size())
model.backward(gradients)
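
If you are on PyTorch 0.4 or later, Variable is deprecated and plain tensors carry requires_grad, so the same hack can be written without Variable at all. Here is a minimal sketch of that modern variant (same toy 4-layer net; the inputs/outputs attribute names are my own):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(4)])

    def forward(self, x):
        self.inputs, self.outputs = [], []
        for layer in self.layers:
            # detach() cuts the graph here; requires_grad_() turns the
            # detached tensor into a fresh leaf so its .grad gets populated
            x = x.detach().requires_grad_(True)
            self.inputs.append(x)
            x = layer(x)
            self.outputs.append(x)
        return x

    def backward(self, g):
        for i, output in reversed(list(enumerate(self.outputs))):
            if i == len(self.outputs) - 1:
                # last segment: seed with the external gradient g
                output.backward(g)
            else:
                # earlier segments: seed with the gradient that arrived
                # at the next segment's detached leaf input
                output.backward(self.inputs[i + 1].grad)

model = Net()
out = model(torch.randn(4, 10))
model.backward(torch.randn_like(out))

To sanity-check either version, run the same input through an identical net with a plain out.backward(gradients) call and compare the layers' weight .grad values; they should match.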