Mini-batching / gradient accumulation within the model

The paper [2004.04725] Instance-aware, Context-focused, and Memory-efficient Weakly Supervised Object Detection introduced sequential backprop: mini-batching of some components within the model, to make it feasible to process a large number of ROIs through a per-ROI ResNet res5 block.

I’ve made the following sketch to wrap ResNet’s res5 block (I haven’t tested it yet).

In theory, can torch.vmap replace the mini-batching loop in backward?

Thanks!

import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    """Runs the wrapped module's backward pass in mini-batches to save memory."""

    def __init__(self, module, batch_size=1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # Run the forward on the detached input without recording a graph;
        # the graph is rebuilt chunk by chunk in the custom backward below
        with torch.no_grad():
            y = self.module(x.detach())
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            # Recompute the forward and run backward one mini-batch at a time,
            # so that only a single mini-batch's activations are alive at once
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                with torch.enable_grad():
                    # clone + retain_grad so that x_mini.grad gets populated below
                    x_mini = x_mini.clone()
                    x_mini.retain_grad()
                    y_mini = ctx.module(x_mini)
                # accumulates the module's parameter gradients and fills x_mini.grad
                torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            return torch.cat(grads), None, None, None

if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)

    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size=16), head)

    print('before', neck.weight.grad)

    x = torch.rand(512, 3)
    model(x).sum().backward()
    print('after', neck.weight.grad)

@alband I had to do nasty clone / retain_grad hacks, otherwise it complained about there being no grad accumulator for the leaf tensor.

Or should this maybe be a feature request for torch.autograd.backward?

Or should this maybe be a feature request for torch.autograd.backward?

What would be the feature?

Note that you don’t have to do retain_grad: you can pass inputs= to backward to get the .grad field populated only for those Tensors.
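
For example, a minimal self-contained sketch (toy module and shapes, not taken from the snippet above):

import torch
import torch.nn as nn

lin = nn.Linear(6, 12)
x = torch.rand(4, 6, requires_grad=True)
y = lin(x)

# Only the tensors listed in inputs= get their .grad populated, so no retain_grad() on x is needed
torch.autograd.backward(y, torch.ones_like(y), inputs=[x])
print(x.grad.shape)     # torch.Size([4, 6])
print(lin.weight.grad)  # None: the weight was not listed in inputs=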

Also, I’m curious: what is the benefit of doing this compared to running forward/backward on the smaller batches and accumulating the gradients?

The feature would be native mini-batching support in torch.autograd.backward. The mini-batches probably could not be processed in parallel, though, because of the memory-saving goal.

It’s that you apply this gradient accumulation only in the middle of the model. If you do traditional gradient accumulation, you’d be underloading the GPU for the earlier part of the model (the backbone).
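
For contrast, traditional gradient accumulation would look roughly like the toy sketch below (reusing the backbone / neck / head example from above): every micro-batch has to go through the whole model.

import torch
import torch.nn as nn

backbone, neck, head = nn.Linear(3, 6), nn.Linear(6, 12), nn.Linear(12, 1)
model = nn.Sequential(backbone, neck, head)

x = torch.rand(512, 3)
# Every micro-batch is pushed through backbone, neck and head, so the backbone only ever
# sees batches of 16, whereas with SequentialBackprop it processes all 512 rows at once
for x_micro in x.split(16):
    model(x_micro).sum().backward()  # gradients accumulate in .grad across micro-batches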

I’m not sure I follow. I need gradients both for x_mini and for neck.weight; that’s why I can’t just use torch.autograd.grad. I tried it and got an error:

                    #x_mini.retain_grad()
                    y_mini = ctx.module(x_mini)
                torch.autograd.backward(y_mini, g_mini, inputs=x_mini)

RuntimeError: One of the differentiated Tensors given as 'inputs' to backward is not a leaf Tensor

That should not happen on the latest version. We added that :slight_smile:

The feature would be native mini-batching support in torch.autograd.backward. The mini-batches probably could not be processed in parallel, though, because of the memory-saving goal.

The thing is that this is not really an autograd thing; it is only valid under the specific assumption that you have a batch of independent samples. That is not true in general when the loss is not simply cumulative over the samples, or as soon as you have a batchnorm.
The workaround of using a for-loop looks simple enough here.
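
A quick toy illustration of the batchnorm point (shapes and module are made up for the example): splitting the batch changes the normalization statistics, so the chunked forward no longer matches the full-batch one.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(6).train()
x = torch.rand(32, 6)

full = bn(x)
chunked = torch.cat([bn(chunk) for chunk in x.split(16)])
# In training mode each chunk is normalized with its own batch statistics,
# so the chunked result generally differs from the full-batch result
print(torch.allclose(full, chunked))  # usually False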

Will it still accumulate gradients into self.module.weight or only produce gradInput?

Also, is there a way to bypass clone() / input copy?

Will it still accumulate gradients into self.module.weight or only produce gradInput?

Only the inputs given. But you can give model.parameters() as well :wink:
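
For example, in the same spirit as the earlier toy sketch, listing the parameters in inputs= makes their .grad get filled in as well:

import torch
import torch.nn as nn

lin = nn.Linear(6, 12)
x = torch.rand(4, 6, requires_grad=True)
y = lin(x)

# Both x and the listed parameters get their .grad populated
torch.autograd.backward(y, torch.ones_like(y), inputs=[x, *lin.parameters()])
print(x.grad.shape, lin.weight.grad.shape, lin.bias.grad.shape)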

Also, is there a way to bypass clone() / input copy?

I’m not 100% sure why you need to clone here; I think detaching and requiring gradient again would work.
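
Applied to the backward from the first post, that would look something like this (untested sketch):

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                # detach().requires_grad_() yields a leaf tensor, so no clone() / retain_grad() hack
                x_mini = x_mini.detach().requires_grad_()
                with torch.enable_grad():
                    y_mini = ctx.module(x_mini)
                # fills x_mini.grad (now a leaf) and still accumulates the module's parameter grads
                torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            return torch.cat(grads), None, None, None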

Thanks! Seems to work. Final version: Mini-batching within the model in PyTorch · GitHub

Looks good. Note that the retain_grad is not needed any more as you already have a leaf Tensor that requires grad.