How to split the backward process w.r.t. each layer of a neural network?

Here's a more precise and complete example. What this example does is avoid autograd's automatic end-to-end backward computation entirely: the input to each layer is detached from the previous history during the forward pass, and the backward pass is then re-chained manually, layer by layer, in reverse order.

For anyone coming here via search: this solution is a hack and not good practice. It is given purely as an illustration, to show @zazzyy how to short-circuit these things.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
            nn.Linear(10, 10),
        ])

    def forward(self, x):
        self.output = []
        self.input = []
        for layer in self.layers:
            # detach from previous history so each layer's input is an
            # independent leaf tensor whose .grad we can read later
            x = x.detach().requires_grad_(True)
            self.input.append(x)

            # compute output
            x = layer(x)

            # add to list of outputs
            self.output.append(x)
        return x

    def backward(self, g):
        for i, output in reversed(list(enumerate(self.output))):
            if i == (len(self.output) - 1):
                # for the last layer, seed backward with the incoming gradient g
                output.backward(g)
            else:
                # for earlier layers, chain using the gradient that autograd
                # stored on the next layer's detached input
                output.backward(self.input[i + 1].grad)
                print(i, self.input[i + 1].grad.sum())

model = Net()
inp = torch.randn(4, 10)
output = model(inp)
gradients = torch.randn(*output.size())
model.backward(gradients)
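
To check that this manual chaining really reproduces what autograd would compute, you can compare the parameter gradients against a plain end-to-end backward pass through the same layers. The snippet below is only a sanity-check sketch added on top of the example above (it is not from the original post); it reuses the Net class defined there and bypasses the detaching forward() for the reference pass:

# sanity check (illustrative sketch, not part of the original example):
# the manually chained gradients should match an ordinary end-to-end backward
model = Net()
inp = torch.randn(4, 10)
gradients = torch.randn(4, 10)

# 1) manual, layer-by-layer backward as defined above
model(inp)
model.backward(gradients)
manual_grads = [p.grad.clone() for p in model.parameters()]

# 2) reference: ordinary autograd backward through the same layers,
#    bypassing the detaching forward() so the graph stays connected
model.zero_grad()
x = inp
for layer in model.layers:
    x = layer(x)
x.backward(gradients)

for g_manual, p in zip(manual_grads, model.parameters()):
    print(torch.allclose(g_manual, p.grad))  # expected: True for every parameter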