Retain sub-graph

Hi,

I have net_A (+) which I want to train, and net_B (-) which is pre-trained and has requires_grad=False.
net_C is as follows:

++++++++++++-----------
net_A -> net_B (iterate N times)

Sample code:

```python
for epoch in range(arg.epochs):
    # data = batch x 3 x 28 x 28
    A = net_A(data)
    # A = batch x N x 2
    for b in range(A.size(0)):
        for n in range(A.size(1)):
            B = net_B(A[b, n:n+1])
            loss = A_loss(B, data)  # some loss that involves net_B and input data
            loss.backward(retain_graph=True)  # keeps every net_B branch alive
    optimizer.step()  # updates net_A parameters
```

In the computation graph, the net_A (+) part is trained and is needed for every backward pass, but a new net_B (-) branch (leaf) is created at every iteration.
retain_graph=True keeps all the (-) paths alive, even though only the current branch is needed, and the run ends up with an “out of memory” error…

How can I retain only the sub-graph (+)?

Thanks

I have been working for a week trying to solve the same problem, and I think I came up with a possible solution.
First, you should define the following function and module:

```python
from collections.abc import Callable, Iterable, Sized

import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class GradAccumulatorFunction(Function):
    # Identity in forward. In backward it either stores the incoming gradient in a
    # shared buffer ("accumulate") or adds the stored total to it ("release").
    @staticmethod
    def forward(ctx, input, accumulated_grad, mode):
        ctx.accumulated_grad = accumulated_grad  # buffer shared by all calls in one pass
        ctx.mode = mode
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        if ctx.mode == "accumulate":
            # input is a detached copy of the shared features: keep its gradient and
            # send nothing further up, so this submodule's graph can be freed
            ctx.accumulated_grad.add_(grad_output)
            return None, None, None
        elif ctx.mode == "release":
            # last submodule: pass the full gradient on to the feature extractor
            return grad_output + ctx.accumulated_grad, None, None
        else:
            raise ValueError(f"invalid mode {ctx.mode}")


class GradAccumulator(nn.Module):  # class name assumed; the declaration line was lost
    """
    Helper module used to accumulate the gradient of the given tensor w.r.t. the output of the criterion.
    Typically used when a feature extractor is followed by several modules that can be calculated
    independently. The module only retains the graph of the last executed submodule and accumulates
    the gradients produced by the former submodules, so the GPU memory used to store their temporary
    variables is saved. It can also be used to extend the effective batch size at little extra memory cost.
    Note that criterion functions will NOT be registered as submodules of this module, which means that
    criterion functions CANNOT contain buffers or parameters, because they will NOT be correctly processed
    by the .cuda() method or the DataParallel layer.
    """

    def __init__(self, criterion_fns, submodules, collect_fn=None, reduce_method="mean"):
        super().__init__()
        assert isinstance(submodules, (Sized, Iterable)), "invalid submodules"
        if isinstance(criterion_fns, (Sized, Iterable)):
            assert len(submodules) == len(criterion_fns)
            assert all(isinstance(submodule, Callable) for submodule in submodules)
            assert all(isinstance(criterion_fn, Callable) for criterion_fn in criterion_fns)
        elif isinstance(criterion_fns, Callable):
            criterion_fns = [criterion_fns for _ in range(len(submodules))]
        else:
            raise ValueError("invalid criterion function")

        self.submodules = nn.ModuleList(submodules)
        self.criterion_fns = criterion_fns
        self.method = reduce_method
        self.collect_fn = collect_fn
        self.func = GradAccumulatorFunction.apply

    def forward(self, tensor):
        outputs = []
        losses = 0
        accumulated_grad = torch.zeros_like(tensor)
        for i, (submodule, criterion) in enumerate(zip(self.submodules, self.criterion_fns)):
            mode = "accumulate" if i < len(self.submodules) - 1 else "release"
            # use a detached copy (not detach_()) so `tensor` keeps its graph for the release step
            inp = tensor.detach().requires_grad_() if mode == "accumulate" else tensor
            # Function.apply only takes positional arguments
            output = submodule(self.func(inp, accumulated_grad, mode))

            loss = criterion(output)
            if self.method == "mean":
                loss = loss / len(self.submodules)
            if mode == "accumulate":
                loss.backward()  # frees this submodule's graph right away
                loss = loss.detach()
                output = output.detach()
            outputs.append(output)
            losses = losses + loss

        if self.collect_fn is not None:
            outputs = self.collect_fn(outputs)

        # only the last submodule's graph is attached; call losses.backward() on it
        return outputs, losses
```
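The mechanism behind the module can also be sketched without a custom autograd Function: cut the graph at net_A's output with `detach()`, let each net_B branch backpropagate into the detached tensor's `.grad` (freeing that branch immediately), and push the summed gradient through net_A once. This is a simpler, swapped-in technique, not the code above; `net_A`, `net_B`, `A_loss`, `N = 4` and the toy shapes are hypothetical stand-ins. The sketch checks that it yields the same net_A gradients as the `retain_graph=True` loop from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: net_A maps batch x 3 x 28 x 28 to batch x N x 2,
# net_B is frozen and applied to one point A[b, n:n+1] at a time.
N = 4
net_A = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, N * 2), nn.Unflatten(1, (N, 2)))
net_B = nn.Linear(2, 1).requires_grad_(False)
data = torch.randn(5, 3, 28, 28)


def A_loss(out, data):
    # hypothetical loss that mixes net_B's output with the input data
    return (out - data.mean()).pow(2).sum()


def grads_retain_graph():
    """Reference: the question's loop, every branch kept alive via retain_graph."""
    net_A.zero_grad()
    A = net_A(data)
    for b in range(A.size(0)):
        for n in range(A.size(1)):
            loss = A_loss(net_B(A[b, n:n + 1]), data)
            loss.backward(retain_graph=True)
    return [p.grad.clone() for p in net_A.parameters()]


def grads_detached_leaf():
    """Only net_A's graph is kept; each net_B branch is freed after its backward."""
    net_A.zero_grad()
    A = net_A(data)
    A_leaf = A.detach().requires_grad_()  # graph is cut between net_A and net_B
    for b in range(A.size(0)):
        for n in range(A.size(1)):
            loss = A_loss(net_B(A_leaf[b, n:n + 1]), data)
            loss.backward()  # dLoss/dA accumulates in A_leaf.grad
    A.backward(A_leaf.grad)  # one pass through net_A with the summed gradient
    return [p.grad.clone() for p in net_A.parameters()]


ref = grads_retain_graph()
new = grads_detached_leaf()
for g_ref, g_new in zip(ref, new):
    assert torch.allclose(g_ref, g_new, rtol=1e-4, atol=1e-6)
```

In the detached version only net_A's graph survives until the final `A.backward(...)`, so no `retain_graph=True` is needed; that is the same effect the module above aims for with its accumulate/release modes.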

Then, the work can be done like this:

```python
class B(nn.Module):
    # wraps the shared, frozen net_B so that it only sees the slice A[b, n:n+1]
    def __init__(self, net_B, ind_tuple):
        super(B, self).__init__()
        self.net_B = net_B
        self.ind_tuple = ind_tuple  # (b, n) index into A

    def forward(self, A):
        b, n = self.ind_tuple
        return self.net_B(A[b, n:n + 1])

for epoch in range(arg.epochs):