Gradients through multiple forwards

Hi all,

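I’m trying to compute the gradient of a loss with respect to the input of a model that is applied several times in a row (50 times in the code below). Unfortunately, my GPU is small, so I cannot backpropagate from the loss through all of the calls at once.
So far, I have tried to compute it directly, like this: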

import torch

# x_in is the input image; model and criterion are defined elsewhere
x_in.requires_grad_(True)
x_t = x_in
for _ in range(50):
    x_t = model(x_t)
loss = criterion(x_t)
grads = torch.autograd.grad(loss, x_in)[0]
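
For reference, here is the quantity the snippet above asks autograd for, using notation introduced only for this note: the loop computes $x_{t+1} = f(x_t)$ for $t = 0, \dots, T-1$, with $f$ the model, $x_0$ the input image, $T = 50$, and $J_t = \partial f / \partial x$ evaluated at $x_t$.

$$
\nabla_{x_0} L \;=\; J_0^{\top} J_1^{\top} \cdots J_{T-1}^{\top} \, \nabla_{x_T} L
$$

Backward evaluates this right to left as one vector-Jacobian product per model call, and each product needs the tensors saved during that call's forward pass, which is why memory grows linearly with the number of steps.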

Because of this memory limit, I tried applying the chain rule iteratively instead: at each step, I compute the derivative of the step's output with respect to its input and multiply these derivatives together. Something like this:

import torch

def chained_grads(model, criterion, x_in, steps=50):
    # x_in is the input image
    x_t = x_in.detach()
    grads = torch.ones_like(x_t)

    for _ in range(steps):
        x_in = x_t.detach().requires_grad_(True)
        x_t = model(x_in)

        # derivative of this step's (summed) output wrt its input
        prev_grads = torch.autograd.grad(x_t.sum(), x_in)[0]

        # accumulate: elementwise product of the per-step derivatives
        with torch.no_grad():
            grads *= prev_grads

    x_in = x_t.detach().requires_grad_(True)
    loss = criterion(x_in)

    return grads * torch.autograd.grad(loss, x_in)[0]

If I do it this way, it works. However, I am not sure whether what I’ve done is correct. Could you give me some insight into how to do it correctly?

Thank you very much
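
One way to check the iterative scheme from the question is to compare it with plain autograd on tiny problems where both fit in memory. The sketch below is a hypothetical test harness: the function names `full_backprop_grad` and `elementwise_product_grad`, the toy models, and the losses are illustrative assumptions, not from the thread. It relies on a property of `torch.autograd.grad(x_t.sum(), x_in)`: it returns $J^\top \mathbf{1}$, the column sums of the step's Jacobian $J$, not $J$ itself. Multiplying those vectors elementwise therefore matches the true chain rule only when every Jacobian is diagonal, i.e. when the model acts elementwise. For `torch.tanh` the two agree; for a fixed 2x2 linear map they do not.

```python
import torch


def full_backprop_grad(model, criterion, x0, steps):
    # Reference: autograd through the whole unrolled loop (keeps every step's graph).
    x_in = x0.clone().requires_grad_(True)
    x_t = x_in
    for _ in range(steps):
        x_t = model(x_t)
    return torch.autograd.grad(criterion(x_t), x_in)[0]


def elementwise_product_grad(model, criterion, x0, steps):
    # The iterative scheme from the question: multiply each step's
    # d(sum of outputs)/d(input), i.e. J^T 1, elementwise.
    x_t = x0.detach()
    grads = torch.ones_like(x_t)
    for _ in range(steps):
        x_in = x_t.detach().requires_grad_(True)
        x_t = model(x_in)
        prev_grads = torch.autograd.grad(x_t.sum(), x_in)[0]
        with torch.no_grad():
            grads *= prev_grads
    x_in = x_t.detach().requires_grad_(True)
    return grads * torch.autograd.grad(criterion(x_in), x_in)[0]


# Elementwise model: every Jacobian is diagonal, so the two agree.
x0 = torch.tensor([0.3, -0.7, 1.2])
sq_loss = lambda x: (x ** 2).sum()  # hypothetical loss
ref_tanh = full_backprop_grad(torch.tanh, sq_loss, x0, steps=5)
iter_tanh = elementwise_product_grad(torch.tanh, sq_loss, x0, steps=5)
print(torch.allclose(ref_tanh, iter_tanh))  # True

# Non-elementwise model: x -> W @ x with a fixed 2x2 matrix, loss = sum.
W = torch.tensor([[1.0, 2.0], [0.0, 1.0]])
lin = lambda x: W @ x
sum_loss = lambda x: x.sum()
ones = torch.tensor([1.0, 1.0])
ref_lin = full_backprop_grad(lin, sum_loss, ones, steps=2)
iter_lin = elementwise_product_grad(lin, sum_loss, ones, steps=2)
print(ref_lin, iter_lin)  # reference [1, 5] vs. iterative [1, 9]
```

In the linear case the true gradient is $(W^\top)^2 \mathbf{1} = [1, 5]$, while the iterative product gives $(W^\top \mathbf{1})^{\odot 2} = [1, 3]^{\odot 2} = [1, 9]$, so a model with any mixing between elements (convolutions, linear layers, normalization) will generally get the wrong answer from the elementwise scheme.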


I’m not sure about your approach. If your GPU is small, you can try using gradient checkpointing (nice visual explanation) with torch.utils.checkpoint in your model’s forward.
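
A minimal sketch of that suggestion applied to the loop from the question, assuming a recent PyTorch where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant=False`. Here each model call is checkpointed from the outside rather than inside the model's forward; the helper name `unrolled_grad_checkpointed`, the toy MLP, and the loss are hypothetical. With this setup, only the input of each step is kept during the forward pass, and the activations inside `model` are recomputed one step at a time during backward.

```python
import torch
from torch.utils.checkpoint import checkpoint


def unrolled_grad_checkpointed(model, criterion, x_in, steps=50):
    """Gradient of criterion(model applied `steps` times to x_in) w.r.t. x_in."""
    x_in = x_in.detach().requires_grad_(True)
    x_t = x_in
    for _ in range(steps):
        # Only this step's input is saved; the activations inside `model`
        # are recomputed when backward reaches this step.
        x_t = checkpoint(model, x_t, use_reentrant=False)
    loss = criterion(x_t)
    return torch.autograd.grad(loss, x_in)[0]


# Toy problem (model, loss and sizes are hypothetical).
torch.manual_seed(0)
model = torch.nn.Sequential(
    torch.nn.Linear(8, 32), torch.nn.Tanh(), torch.nn.Linear(32, 8)
)
criterion = lambda x: (x ** 2).mean()
x0 = torch.randn(4, 8)

g_ckpt = unrolled_grad_checkpointed(model, criterion, x0)

# Reference: plain backprop through the unrolled loop (stores every activation).
x_ref = x0.clone().requires_grad_(True)
x_t = x_ref
for _ in range(50):
    x_t = model(x_t)
g_plain = torch.autograd.grad(criterion(x_t), x_ref)[0]

print(torch.allclose(g_ckpt, g_plain))
```

The trade-off is compute for memory: each model call is run forward a second time during backward, so the backward pass costs roughly one extra forward pass in exchange for not storing 50 calls' worth of internal activations.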

Thank you very much! I didn’t know that function existed.
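
For comparison, the iterative idea from the question can also be made correct by hand, which is essentially what checkpointing automates: run the forward pass without a graph while storing each step's input, then walk the steps in reverse, rebuilding one step's graph at a time and propagating a vector-Jacobian product through `grad_outputs`. This is a sketch under stated assumptions: `manual_checkpointed_grad` and the toy models are hypothetical, and `model` must be deterministic (e.g. no dropout, or in eval mode), because each step is recomputed during the reverse pass.

```python
import torch


def manual_checkpointed_grad(model, criterion, x_in, steps=50):
    """Input gradient via step-by-step vector-Jacobian products (assumes deterministic model)."""
    # Forward without building a graph; keep only each step's input.
    inputs = []
    x_t = x_in.detach()
    with torch.no_grad():
        for _ in range(steps):
            inputs.append(x_t)
            x_t = model(x_t)

    # Gradient of the loss w.r.t. the final output.
    x_last = x_t.detach().requires_grad_(True)
    g = torch.autograd.grad(criterion(x_last), x_last)[0]

    # Reverse pass: rebuild one step's graph at a time and apply J_t^T to g.
    for x_step in reversed(inputs):
        x_step = x_step.detach().requires_grad_(True)
        g = torch.autograd.grad(model(x_step), x_step, grad_outputs=g)[0]
    return g


# 2x2 linear map (hypothetical example): x -> W @ x, loss = sum.
W = torch.tensor([[1.0, 2.0], [0.0, 1.0]])
g_lin = manual_checkpointed_grad(lambda x: W @ x, lambda x: x.sum(),
                                 torch.tensor([1.0, 1.0]), steps=2)
print(g_lin)  # tensor([1., 5.])

# Toy MLP checked against plain backprop through the unrolled loop.
torch.manual_seed(0)
mlp = torch.nn.Sequential(
    torch.nn.Linear(8, 32), torch.nn.Tanh(), torch.nn.Linear(32, 8)
)
crit = lambda x: (x ** 2).mean()
x0 = torch.randn(4, 8)
g_manual = manual_checkpointed_grad(mlp, crit, x0)

x_ref = x0.clone().requires_grad_(True)
out = x_ref
for _ in range(50):
    out = mlp(out)
g_ref = torch.autograd.grad(crit(out), x_ref)[0]
print(torch.allclose(g_manual, g_ref))
```

Unlike the elementwise product, this applies the full transposed Jacobian of each step to the incoming gradient, so it does not depend on the model acting elementwise. Peak memory is the stored step inputs plus one step's graph.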