Retain sub-graph


I have net_A (+), which I want to train, and net_B (-), which is pre-trained and has requires_grad=False.
The combined network net_C is built as follows:

net_A -> net_B (iterate N times)

Sample code:

# Training loop from the question (illustrative). net_A, net_B, dataloader,
# A_loss, optimizer and arg are defined elsewhere.
# NOTE(review): no backward call appears in this excerpt. Per the question text,
# the author presumably called backward with retain_graph=True inside the inner
# loop, which keeps every net_B branch's graph alive until the memory runs out.
for epoch in range(arg.epochs):
    for i, data in enumerate(dataloader):
        # data = batch x 3 x 28 x 28
        A = net_A(data)
        # A = batch x N x 2
        for b in range(A.size(0)):
            for n in range(A.size(1)):
                # Feed one slice of A (1 x 2 if the shape comment above holds)
                # through net_B; the slice stays attached to net_A's graph.
                B = net_B(A[b, n:n+1])
                loss = A_loss(B, data) # some loss that involves net_B and input data
        optimizer.step() # net_A parameters

In the computation graph, net_A (+) is trained and is needed for every backward pass, but the net_B (-) branches (leaves) change on every iteration.
Using retain_graph=True keeps all of the (-) paths alive, even though only the current branch is needed, and training eventually fails with an “out of memory” error…

How can I retain only the sub-graph (+)?


I have been working on the same problem for a week, and I think I have found a possible solution.
First, define the following function and module:

class GradAccumulatorFunction(Function):
    """Identity autograd function that accumulates or injects gradients.

    ``forward`` returns ``input`` unchanged (as a view). What ``backward`` does
    depends on ``mode``:

    * ``"accumulate"``: the incoming gradient is added in place to
      ``accumulated_grad`` and nothing is propagated. ``input`` must not
      require grad in this mode, and a buffer must be supplied.
    * ``"release"``: the incoming gradient is added to ``accumulated_grad``
      (when one is given) and the total is propagated to ``input``.

    Args (to ``apply``):
        input: tensor passed through unchanged.
        accumulated_grad: gradient buffer with the same shape as ``input``, or None.
        mode: ``"accumulate"`` or ``"release"``.

    Raises:
        ValueError: for an unknown ``mode``, or for accumulate mode without a buffer.
    """

    @staticmethod
    def forward(ctx, input, accumulated_grad=None, mode="release"):
        if mode not in ("accumulate", "release"):
            raise ValueError(f"invalid mode {mode}")
        if mode == "accumulate" and accumulated_grad is None:
            raise ValueError("accumulate mode requires an accumulated_grad buffer")
        ctx.mode = mode
        # The buffer must be saved explicitly; backward reads it via saved_tensors.
        ctx.save_for_backward(accumulated_grad)
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        accumulated_grad, = ctx.saved_tensors
        if ctx.mode == "accumulate":
            assert not ctx.needs_input_grad[0], "input tensors should not need gradient in accumulate mode"
            accumulated_grad += grad_output
            # One gradient slot per forward input: input, accumulated_grad, mode.
            return None, None, None
        if ctx.mode == "release":
            if not ctx.needs_input_grad[0]:
                return None, None, None
            if accumulated_grad is not None:
                accumulated_grad += grad_output
            else:
                accumulated_grad = grad_output
            return accumulated_grad, None, None
        raise ValueError(f"invalid mode {ctx.mode}")

class GradAccumulator(nn.Module):
    """Accumulate the gradient of a shared tensor over several independent submodules.

    Helper module used to accumulate the gradient of the given tensor w.r.t. the
    output of each criterion. It is typically used when a feature extractor is
    followed by several modules that can be computed independently. Only the
    graph of the last submodule is retained. Every earlier submodule is run on a
    detached copy of the input and back-propagated immediately, and its gradient
    w.r.t. the input is accumulated into ``grad_buffer``. This frees the GPU
    memory that would otherwise hold those submodules' intermediate tensors. It
    can also be used to extend the effective batch size at little memory cost.

    Calling ``backward()`` on the returned loss propagates the last submodule's
    gradient plus everything accumulated in ``grad_buffer`` into ``tensor``.

    Note that criterion functions are NOT registered as submodules of this
    module. They therefore CANNOT contain buffers or parameters, because those
    would not be handled correctly by ``.cuda()`` or DataParallel.

    Args:
        criterion_fns: one callable per submodule, or a single callable reused
            for every submodule.
        submodules: iterable of modules, each applied to the shared input tensor.
        collect_fn: optional callable applied to the list of submodule outputs.
        reduce_method: ``"mean"`` divides each loss by the number of submodules;
            any other value leaves the losses unscaled (summed).
    """

    def __init__(self, criterion_fns, submodules, collect_fn=None, reduce_method="mean"):
        super().__init__()
        assert isinstance(submodules, (Sized, Iterable)), "invalid submodules"
        # Materialize so len() works even for generators.
        submodules = list(submodules)
        if isinstance(criterion_fns, (Sized, Iterable)):
            criterion_fns = list(criterion_fns)
            assert len(submodules) == len(criterion_fns)
            assert all(isinstance(submodule, Callable) for submodule in submodules)
            assert all(isinstance(criterion_fn, Callable) for criterion_fn in criterion_fns)
        elif isinstance(criterion_fns, Callable):
            # A single criterion is reused for every submodule.
            criterion_fns = [criterion_fns] * len(submodules)
        else:
            raise ValueError("invalid criterion function")

        self.submodules = nn.ModuleList(submodules)
        self.criterion_fns = criterion_fns
        self.method = reduce_method
        self.grad_buffer = None
        self.func = GradAccumulatorFunction.apply
        self.collect_fn = collect_fn

    def forward(self, tensor):
        """Run every submodule on ``tensor``.

        Returns:
            ``(outputs, losses)``. ``outputs`` holds the submodule outputs (all
            but the last are detached), passed through ``collect_fn`` if one is
            set. ``losses`` is the reduced total loss; only its last term is
            still attached to the graph of ``tensor``.
        """
        outputs = []
        losses = 0
        count = len(self.submodules)
        self.grad_buffer = torch.zeros_like(tensor)
        for i, (submodule, criterion) in enumerate(zip(self.submodules, self.criterion_fns)):
            is_last = i == count - 1
            if is_last:
                # Reconnect to the real tensor. The release-mode backward adds
                # grad_buffer to the incoming gradient.
                output = submodule(self.func(tensor, self.grad_buffer, "release"))
            else:
                # Cut the graph so this submodule's intermediates are freed
                # right after its own backward pass below.
                leaf = tensor.detach().requires_grad_(tensor.requires_grad)
                output = submodule(leaf)

            loss = criterion(output)
            if self.method == "mean":
                # Out-of-place, so the autograd graph is not modified in place.
                loss = loss / count

            if not is_last:
                if torch.is_tensor(loss) and loss.requires_grad:
                    loss.backward()
                if leaf.grad is not None:
                    self.grad_buffer += leaf.grad
                if torch.is_tensor(loss):
                    loss = loss.detach()
                if torch.is_tensor(output):
                    output = output.detach()

            outputs.append(output)
            losses = losses + loss

        if self.collect_fn is not None:
            outputs = self.collect_fn(outputs)

        return outputs, losses

Then, the work can be done like this:

class B(nn.Module):
    """Apply ``net_B`` to a single ``(b, n)`` slice of the input.

    Args:
        net_B: callable applied to the selected slice.
        ind_tuple: indices ``(b, n)``. ``forward`` selects ``A[b, n:n + 1]``,
            keeping the ``n`` axis with length 1.
    """

    def __init__(self, net_B, ind_tuple):
        super().__init__()
        self.net_B = net_B
        # forward() reads ind_tuple; the original only set ind_list, so every
        # forward call raised AttributeError.
        self.ind_tuple = ind_tuple
        # Backward-compatible alias for the previously used attribute name.
        self.ind_list = ind_tuple

    def forward(self, A):
        b, n = self.ind_tuple[0], self.ind_tuple[1]
        return self.net_B(A[b, n:n + 1])

# Training loop using GradAccumulator. net_A, net_B, dataloader, A_loss,
# optimizer and arg are defined elsewhere.
for epoch in range(arg.epochs):
    for i, data in enumerate(dataloader):
        optimizer.zero_grad()
        # data = batch x 3 x 28 x 28
        A = net_A(data)
        # A = batch x N x 2
        net_B_ = GradAccumulator(
            # Bind the current batch explicitly rather than relying on late binding.
            [(lambda x, data=data: A_loss(x, data)) for _ in range(A.size(0) * A.size(1))],
            [B(net_B, (b, n)) for n in range(A.size(1)) for b in range(A.size(0))],
        )
        outs, loss = net_B_(A)
        # Earlier branches were already back-propagated inside net_B_. This call
        # sends the last branch's gradient plus the accumulated buffer into net_A.
        loss.backward()
        optimizer.step() # net_A parameters