Performance issues with nn.ModuleList() and growing architectures

PyTorch version: 2.1

I am building growing networks, i.e., models that add new parameters as needed during training. The problem is that the shorter, more elegant solution seems to have a memory leak as well as performance issues (it is much slower).

One of the possible implementations in PyTorch, let’s call it A, is to use nn.ModuleList() and simply add new layers to that list. The entire network architecture is, of course, more complex but a simple example would be:

class GrowingLayer(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.layer_list = nn.ModuleList()
        ...

    def forward(self, x):
        # apply every layer to the same input, then sum up the results
        layer_activations = [layer(x) for layer in self.layer_list]
        return torch.stack(layer_activations, dim=0).sum(dim=0)

    def grow(self):
        # growing = appending another layer to the list
        self.layer_list.append(nn.Conv2d(...))

Another possible implementation, let’s refer to it as B, is to have just a single Conv2d layer and, for growing, simply create a new, bigger Conv2d layer and copy over the weights of the previously used layer so that the learned information is not discarded. A simple example might be:

class GrowingConv2d(nn.Module):
    def __init__(self, ...):
        ...
        self.conv = nn.Conv2d(...)

    def grow(self, new_out_channels):
        new_conv = nn.Conv2d(self.conv.in_channels, new_out_channels, ...)

        # copy weights
        with torch.no_grad():
            new_conv.weight[:self.conv.out_channels] = self.conv.weight
            if self.conv.bias is not None:
                new_conv.bias[:self.conv.out_channels] = self.conv.bias

        # replace old with new layer
        self.conv = new_conv

    def forward(self, x):
        return self.conv(x)

Approach A would be the preferred choice for experiments since we are dealing with individual layers packed into module lists (arguably an elegant use of module lists), and thus we can easily put each of them into a different optimizer parameter group. For example, when we add new parameters on growing:

        self.optim.add_param_group(
            {
                "params": new_params,
                ...
            }
        )

This is very handy because with just a few lines of code we can (1) grow different parts of a large architecture, (2) keep momentum, learning rate, weight decay, … for the “old” weights, and (3) initialize the new params in a new optimizer group however we see fit for the ongoing training.
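In full, a grow step then looks roughly like the sketch below (simplified; grow_step, the layer shapes, and new_lr are placeholders rather than my actual code, and add_param_group is the optimizer API used to register the new group):

import torch
import torch.nn as nn

def grow_step(model, optim, new_lr):
    # Sketch of a grow step: append a fresh layer and give its parameters
    # their own optimizer parameter group (shapes and names are placeholders).
    new_layer = nn.Conv2d(16, 16, kernel_size=3, padding=1)
    # match the device of the existing parameters (assumes the model already has some)
    new_layer = new_layer.to(next(iter(model.parameters())).device)
    model.layer_list.append(new_layer)
    optim.add_param_group(
        {
            "params": new_layer.parameters(),
            "lr": new_lr,  # new params get their own hyperparameters, old groups stay untouched
        }
    )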

Achieving the same with solution B, where we would have to index weight ranges inside a single layer, is much more challenging since (1) we would need to store all of that bookkeeping information per indexed weight range, and (2) an indexed weight range is not an nn.Parameter, so we lose a lot of autograd features and need to write a lot more code manually.
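To illustrate (2): slicing a weight gives back a plain, non-leaf tensor rather than a separate nn.Parameter, so it cannot simply be handed to the optimizer as its own group:

import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)
sliced = conv.weight[:4]                      # a "weight range" covering the first 4 output channels
print(isinstance(conv.weight, nn.Parameter))  # True
print(isinstance(sliced, nn.Parameter))       # False -> not a parameter of its own
print(sliced.is_leaf)                         # False -> gradients accumulate on conv.weight, not on the slice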

Moreover, using container structures to store new layers/parameters is intuitive and requires significantly fewer lines of code across an entire training setup. Indexing weights manually is more prone to bugs, less intuitive, and requires a lot of manual implementation in other parts of the code (such as the optimizer).

In terms of convergence and training behavior, approach A performs well and as expected, but with respect to performance (speed of the forward + backward passes) and especially memory usage it is roughly an order of magnitude worse. Approach B performs as you would expect from any regular Conv2d layer.

Now, an aggregation of results in pure Python (a list comprehension) like this

layer_activations = [layer(x) for layer in self.layer_list]

should be somewhat slower, but I do not see any reason why this approach would be roughly an order of magnitude slower.
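(Timings could be compared with something like the sketch below; torch.utils.benchmark takes care of warm-up and CUDA synchronization. The model and input here are placeholders, not my actual setup.)

import torch
import torch.utils.benchmark as benchmark

def time_forward_backward(model, x):
    # Sketch: time one forward + backward pass of a given model on input x.
    def step():
        model(x).sum().backward()
    return benchmark.Timer(stmt="step()", globals={"step": step}).blocked_autorange()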

What is even worse: there seems to be a memory leak with approach A. With approach B, the VRAM usage increases as expected (close to linearly with the number of weights you add). With approach A, however, a model started at 800K total parameters and grown to 2M parameters goes from 3-4 GB of VRAM usage up to 40 GB. Each call of the grow() method seems to add a very large overhead on top of the parameters that actually get added.
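(For reference, peak memory per training step can be compared between the two approaches with a sketch like this; the real model and data are of course different.)

import torch

def peak_memory_per_step(model, x):
    # Sketch: peak allocated CUDA memory for one forward + backward pass (CUDA only).
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    model(x).sum().backward()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1024**2  # MiB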

Has anyone encountered similar issues?

Is there maybe a solution that is similar in simplicity to A but faster / more memory efficient?

It would actually be very handy to have a map-like or foreach-like function for nn.ModuleList that can apply some input to all layers in that list.
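Such a helper is easy enough to write manually, although it would of course still loop over the layers in Python; a sketch of what I mean (not an existing PyTorch API):

import torch.nn as nn

def module_list_map(modules: nn.ModuleList, x):
    # apply the same input to every module in the list (foreach-style helper)
    return [m(x) for m in modules]

def module_list_sum(modules: nn.ModuleList, x):
    # same, but reduce with a running sum instead of materializing all outputs
    # (assumes the list is non-empty)
    it = iter(modules)
    out = next(it)(x)
    for m in it:
        out = out + m(x)
    return out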

Apart from all of that, torch.compile did not help either: performance for A did not increase at all and memory usage was the same. I have not yet tried re-compiling after every growing operation, but I will test that soon as well.
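What I plan to try looks roughly like the sketch below (it is an assumption on my side that a fresh torch.compile wrapper after each grow is enough; torch._dynamo.reset() additionally clears previously compiled graphs):

import torch

def recompile_after_grow(model):
    # grow the model, drop cached compiled graphs, and wrap it in a fresh torch.compile
    model.grow()
    torch._dynamo.reset()
    return torch.compile(model)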

I think the issue is that approach A grows the autograd graph more than approach B: the intermediate activations between the stack and the sum will be saved for the backward pass, whereas approach B has no corresponding intermediate activations that consume memory.
approach A:
e.g., conv1 → out1, conv2 → out2
out1 and out2 will be saved for autograd
stack = [out1, out2]
stack will be saved for autograd
sum = stack …
sum will be saved for autograd

approach B:
conv → out
out will be saved for autograd

You may check if you can do a torch.concat in approach A and simply grow the number of output channels, but this will also incur a copy (though perhaps less than what is currently done).
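Roughly something like this (an untested sketch, assuming the downstream layers can consume the growing channel dimension):

def forward(self, x):
    # concatenate the per-layer outputs along the channel dimension instead of
    # stacking and summing, so the module looks like a single conv whose
    # out_channels grows over time
    return torch.concat([layer(x) for layer in self.layer_list], dim=1)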

Thank you for the reply and the suggestion @eqy.

You may check if you can do a torch.concat

Interestingly, torch.concat did not actually improve the situation (at least not beyond possible measurement error). I will try to write up a minimal example soon, since testing with a big model and framework (1) takes too much time, (2) the impact on performance might be concealed by other things happening, and (3) reproducibility and testing should be better with a minimal example (especially for other people).

For Approach A, you can make it much much more memory-efficient if you rewrite it like this:

def forward(self, x):
    out = None
    for layer in self.layer_list:
        out_layer = layer(x)
        if out is None:
            out = out_layer
        else:
            # accumulate in place instead of keeping every intermediate output alive
            out.add_(out_layer)
    return out

@smth ah yes, of course: make it an in-place operation and add the values onto the existing tensor after the first iteration, instead of computing every single result and storing it separately in memory before summing. Very good, thank you :slight_smile: