I have been working on this same problem for a week, and I think I have found a possible solution.
First, you should define the following autograd Function and helper module:
class GradAccumulatorFunction(Function):
    """Identity autograd Function that buffers or augments the incoming gradient.

    Forward returns ``input`` unchanged. Backward depends on ``mode``:

    * ``"accumulate"``: add the incoming gradient to ``accumulated_grad`` in
      place and stop propagation (``None`` is returned for ``input``).
    * ``"release"``: add the incoming gradient to ``accumulated_grad`` (when a
      buffer was given) and propagate the total to ``input``.
    """

    @staticmethod
    def forward(ctx, input, accumulated_grad=None, mode="release"):
        """Validate arguments, remember the buffer and mode, return ``input``.

        Args:
            input: tensor passed through unchanged. In "accumulate" mode it must
                be a leaf (e.g. ``t.detach().requires_grad_()``). Backward returns
                no gradient for it, so it must not be attached to a graph that
                still needs to be back-propagated later.
            accumulated_grad: tensor buffer updated in place during backward;
                required in "accumulate" mode, optional in "release" mode.
            mode: ``"accumulate"`` or ``"release"``.

        Raises:
            ValueError: unknown ``mode``, or invalid arguments for
                "accumulate" mode.
        """
        # Validate here rather than in backward: backward may never run (e.g.
        # when no input requires grad), which would hide the mistake.
        if mode not in ("accumulate", "release"):
            raise ValueError(f"invalid mode {mode}")
        if mode == "accumulate":
            if accumulated_grad is None:
                raise ValueError("accumulate mode requires an accumulated_grad buffer")
            if input.grad_fn is not None:
                raise ValueError(
                    "input must be a leaf tensor in accumulate mode "
                    "(e.g. tensor.detach().requires_grad_())"
                )
        ctx.save_for_backward(accumulated_grad)
        ctx.mode = mode
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Update the buffer and return the gradient for ``input`` (or ``None``)."""
        (accumulated_grad,) = ctx.saved_tensors
        if ctx.mode == "accumulate":
            # Record this contribution; the gradient stops here by design.
            accumulated_grad += grad_output
            return None, None, None
        # "release" mode (the mode was validated in forward).
        if not ctx.needs_input_grad[0]:
            return None, None, None
        if accumulated_grad is None:
            return grad_output, None, None
        # Fold the current gradient into the buffered ones and hand the total
        # to the input.
        accumulated_grad += grad_output
        return accumulated_grad, None, None
class GradAccumulator(nn.Module):
    """
    Helper module that accumulates the gradient of a shared input tensor w.r.t. several criteria.

    Typically used when a feature extractor is followed by several modules that can be computed
    independently. Every submodule except the last is back-propagated immediately and its graph is
    released; only the gradient that reached the shared input is kept in a buffer. The last
    submodule's graph is kept, and back-propagating the returned loss delivers that gradient plus
    the buffered ones to whatever produced the input. This saves the GPU memory that the former
    submodules' intermediate values would otherwise occupy, and can also be used to extend the
    effective batch size at little memory cost.

    Note that criterion functions will NOT be registered as submodules of this module, which means
    that criterion functions CANNOT contain buffers or parameters, because they will NOT be
    correctly processed by the .cuda() method or a DataParallel layer.

    Args:
        criterion_fns: one callable per submodule, or a single callable shared by all of them.
            Each is called as ``criterion(output)`` and must return a scalar tensor.
        submodules: sized iterable of modules, each called as ``submodule(tensor)``.
        collect_fn: optional callable applied to the list of outputs before returning it.
        reduce_method: ``"mean"`` divides each loss by the number of submodules; any other value
            leaves the losses summed.
    """

    def __init__(self, criterion_fns, submodules, collect_fn=None, reduce_method="mean"):
        super(GradAccumulator, self).__init__()
        assert isinstance(submodules, (Sized, Iterable)), "invalid submodules"
        # Materialize once so one-shot iterables (e.g. generators) support len()
        # and can still be consumed by nn.ModuleList afterwards.
        submodules = list(submodules)
        if isinstance(criterion_fns, (Sized, Iterable)):
            criterion_fns = list(criterion_fns)
            assert len(submodules) == len(criterion_fns)
            assert all(callable(submodule) for submodule in submodules)
            assert all(callable(criterion_fn) for criterion_fn in criterion_fns)
        elif callable(criterion_fns):
            # A single criterion is shared by every submodule.
            criterion_fns = [criterion_fns] * len(submodules)
        else:
            raise ValueError("invalid criterion function")
        self.submodules = nn.ModuleList(submodules)
        self.criterion_fns = criterion_fns
        self.method = reduce_method
        self.grad_buffer = None
        self.func = GradAccumulatorFunction.apply
        self.collect_fn = collect_fn

    def forward(self, tensor):
        """Run every submodule on ``tensor`` and return ``(outputs, losses)``.

        Every submodule except the last receives a fresh leaf copy of ``tensor``
        (``tensor`` itself is never modified). Its loss is back-propagated right
        away, and the gradient that reached the leaf is added to
        ``self.grad_buffer``. The last submodule receives ``tensor`` through
        ``GradAccumulatorFunction`` in "release" mode. Calling
        ``losses.backward()`` therefore sends the last gradient plus all
        buffered ones to whatever produced ``tensor``.

        Returns:
            outputs: list of detached submodule outputs (passed through
                ``collect_fn`` when one was given).
            losses: sum of the (optionally averaged) losses. Only the last term
                still carries an autograd graph; the result is ``0`` if there
                are no submodules.
        """
        num_submodules = len(self.submodules)
        outputs = []
        losses = 0
        self.grad_buffer = torch.zeros_like(tensor) if num_submodules else None
        for i, (submodule, criterion) in enumerate(zip(self.submodules, self.criterion_fns)):
            is_last = i == num_submodules - 1
            if is_last:
                # Positional arguments: older PyTorch releases reject keyword
                # arguments to Function.apply.
                submodule_input = self.func(tensor, self.grad_buffer, "release")
            else:
                # A fresh leaf cuts this submodule's graph off from the caller's
                # graph without mutating `tensor`; detach_() would permanently
                # sever the caller's graph and block all upstream gradients.
                submodule_input = tensor.detach().requires_grad_()
            output = submodule(submodule_input)
            loss = criterion(output)
            if self.method == "mean":
                loss = loss / num_submodules
            if not is_last:
                # Back-propagate now so this submodule's graph is freed.
                # Skipped when there is no graph (e.g. under torch.no_grad()).
                if loss.requires_grad:
                    loss.backward()
                    if submodule_input.grad is not None:
                        self.grad_buffer.add_(submodule_input.grad)
                loss = loss.detach()
            outputs.append(output.detach())
            losses = losses + loss
        if self.collect_fn is not None:
            outputs = self.collect_fn(outputs)
        return outputs, losses
Then the training loop can use it like this:
class B(nn.Module):
    """Apply ``net_B`` to a single (batch, index) entry of the shared tensor ``A``.

    Args:
        net_B: module applied to the selected slice.
        ind_tuple: ``(b, n)`` pair; ``forward`` feeds ``A[b, n:n + 1]`` to ``net_B``.
    """

    def __init__(self, net_B, ind_tuple):
        super(B, self).__init__()
        self.net_B = net_B
        # forward() reads `ind_tuple`; storing it only as `ind_list` made every
        # forward() call raise AttributeError. `ind_list` is kept as an alias
        # for any caller that reads it.
        self.ind_tuple = ind_tuple
        self.ind_list = ind_tuple

    def forward(self, A):
        b, n = self.ind_tuple[0], self.ind_tuple[1]
        # Slicing n:n + 1 (instead of indexing n) keeps that dimension with
        # length 1.
        return self.net_B(A[b, n:n + 1])
for epoch in range(arg.epochs):
    for i, data in enumerate(dataloader):
        net_A.zero_grad()
        # data = batch x 3 x 28 x 28
        A = net_A(data)
        # A = batch x N x 2
        batch_size, num_points = A.size(0), A.size(1)

        # One criterion shared by all submodules: GradAccumulator replicates a
        # single callable itself, so there is no need to build one lambda per
        # submodule. `data` is bound as a default so the function always uses
        # this batch.
        def criterion(output, data=data):
            return A_loss(output, data)

        net_B_ = GradAccumulator(
            criterion,
            [B(net_B, (b, n)) for n in range(num_points) for b in range(batch_size)],
        )
        outs, loss = net_B_(A)
        # Intended to back-propagate the last submodule's loss together with
        # the gradient GradAccumulator buffered from all earlier submodules
        # into net_A.
        loss.backward()
        optimizer.step()  # net_A parameters
        # NOTE(review): only net_A's gradients are zeroed each step; net_B's
        # .grad keeps accumulating. Confirm whether net_B is meant to be
        # frozen or zeroed as well.