What you are asking for is not strictly possible without writing custom autograd.Function subclasses to insert these custom nodes. Either that, or you do some book-keeping and manage the graph yourself:
model = nn.ModuleList([nn.Linear(100, 200), nn.ReLU(), nn.Linear(200, 300)])
x = torch.randn(10, 100, requires_grad=True)

def model_forward(model, x):
    for m in model:
        x = m(x)
    return x

def model_backward(model, grad_output):
    # Pseudocode: nn.Module has no per-module backward(grad_output) method,
    # so you would have to implement this shortcut outside of autograd yourself.
    for m in reversed(model):
        grad_output = m.backward(grad_output)
    return grad_output
Something along these lines. It’s super hacky, and you have to do all the model book-keeping yourself.
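For comparison, here is a minimal sketch of the autograd.Function route mentioned at the top. `ScaleGrad` is a hypothetical example node (the name and the scaling behavior are my own choices, not from your question): it is the identity in the forward pass but runs arbitrary custom code in the backward pass, which is exactly where you could hook in per-node logic.

```python
import torch

class ScaleGrad(torch.autograd.Function):
    # Hypothetical custom node: identity forward, rescaled gradient backward.
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale  # stash whatever backward() will need
        return x

    @staticmethod
    def backward(ctx, grad_output):
        # Arbitrary custom code runs here during the backward pass.
        # Return one gradient per forward input (None for non-tensor args).
        return grad_output * ctx.scale, None

x = torch.randn(10, 100, requires_grad=True)
y = ScaleGrad.apply(x, 0.5)
y.sum().backward()
# x.grad is now filled with 0.5 instead of 1.0
```

The advantage over the manual loop above is that autograd still owns the graph: you can freely mix `ScaleGrad.apply` with ordinary modules and a single `backward()` call does the right thing.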