Dynamically Expandable Networks - Resizing Weights during Training

Hey,
I’m trying to reproduce something along the lines of the recent ICLR18 paper “Lifelong Learning with Dynamically Expandable Networks”: https://arxiv.org/abs/1708.01547 where the idea is to add units in-place to a trained network while further training.

I have a decent amount of what I think is functional code. However, there are two big questions I have left:

1.) When I resize the weights of a convolutional layer to, let’s say, add “x” filters, I have to apply resize_() twice (note that I left out the bias for simplicity, which also needs to be resized):

a) in my current layer i’s output dimensionality: layer[i].weight.data.resize_(layer[i].weight.data.size(0) + x, layer[i].weight.data.size(1), layer[i].weight.data.size(2), layer[i].weight.data.size(3))
b) in the subsequent layer (i+1)'s input dimensionality:
layer[i+1].weight.data.resize_(layer[i+1].weight.data.size(0), layer[i+1].weight.data.size(1) + x, layer[i+1].weight.data.size(2), layer[i+1].weight.data.size(3))

Followed by some initialization of whatever has just been added.

Now executing a) doesn’t seem to be difficult, and I believe it can be done in-place as is. Given the underlying storage (if I understand correctly, a flattened 1-D C array), this resize_() operation will simply append to the array. Thus the previous information should be preserved, and we can just initialize the newly added units by slicing.
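To make a) concrete, here is a minimal sketch including the bias (the layer, the value of x and the init scheme are just example placeholders, not my actual code):

import torch.nn as nn

x = 2                                   # number of new filters to add (example value)
layer = nn.Conv2d(3, 6, 3, padding=1)   # stand-in for layer[i]
out_c, in_c, kh, kw = layer.weight.shape
layer.weight.data.resize_(out_c + x, in_c, kh, kw)   # existing filters stay at the front of the storage
layer.weight.data[out_c:].normal_(0, 0.01)           # initialize only the appended filters
layer.bias.data.resize_(out_c + x)                   # the bias has to grow as well
layer.bias.data[out_c:].zero_()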

Doing operation b), however, seems to be difficult, as the arrangement of the information in the tensor isn’t preserved.
Currently I am dealing with this issue by making an entire clone() of the weight and then copying the corresponding slice back. Needless to say, this is pretty bad in terms of memory, and also in terms of time if the copy has to go through the CPU. Is there a more efficient way of doing such a thing?
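For reference, the work-around for b) currently looks roughly like this (a sketch; next_layer and x are placeholder names, and the init is just an example):

old = next_layer.weight.data.clone()                      # full copy of the old weight
out_c, in_c, kh, kw = old.shape
next_layer.weight.data.resize_(out_c, in_c + x, kh, kw)   # the values get scrambled by this
next_layer.weight.data[:, :in_c] = old                    # restore the original filters
next_layer.weight.data[:, in_c:].normal_(0, 0.01)         # initialize the new input slices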

2.) The operations from 1.) don’t seem to automatically update the gradient’s shape or pass that information to autograd.

I thus apply the same resizing operations to the respective dimensions of layer[i].weight.grad.data to make sure the backward pass will match the forward pass.

Now this alone doesn’t seem to be enough during training: it looks like I also need to create a new instance of my optimizer every time I change some dimensionality in the graph, so that it updates the new parameters.
The above code is therefore always followed by a new instance of optimizer = torch.optim.SGD(...). I guess I could also resize the parameters there instead, if it matters.
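In code, the whole step 2.) currently looks roughly like this (a sketch; layer, i, model and lr are placeholder names):

if layer[i].weight.grad is not None:
    layer[i].weight.grad.data.resize_as_(layer[i].weight.data)  # keep the grad shape in sync with the weight
optimizer = torch.optim.SGD(model.parameters(), lr=lr)          # recreate so it tracks the resized params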

Empirically, this gradient resizing plus the new optimizer instance seems to work, but could anyone tell me whether there are any pitfalls I am not thinking of? Is this implementation correct? Alternatively, is there another way of doing this?

I think these questions are somewhat in-depth, and it is hard for me to figure out whether such an implementation is correct beyond empirical observations. So I would really appreciate any feedback!

PS: In case anyone is interested, I implemented the resizing in a forward_pre_hook, but I don’t think that affects the questions.


To 1):
The code itself looks good. Why do you have to copy back to CPU and then to GPU again?
Also note that .resize_ doesn’t initialize the values, so maybe you would like to fill the new filters with some values? I’m not familiar with your method, so maybe it will be sufficient as is.
I think the code might be easier if you use the functional API.
I created a very simple example just for one layer:

import torch
import torch.nn.functional as F
import torch.optim as optim

# Setup
weight1 = torch.randn(6, 1, 3, 3, requires_grad=True)
x = torch.randn(1, 1, 10, 10)
optimizer = optim.SGD([weight1], lr=1e-3)

# First training step
output = F.conv2d(x, weight1, bias=None, stride=1, padding=1)
output.mean().backward()
print(weight1.grad)
optimizer.step()
optimizer.zero_grad() # not necessary if you change weight1 immediately

# Add additional filter and train
weight1 = torch.cat((weight1.data, torch.zeros(1, 1, 3, 3)), dim=0)
weight1.requires_grad_(True)
optimizer = optim.SGD([weight1], lr=1e-3)
output = F.conv2d(x, weight1, bias=None, stride=1, padding=1)
output.mean().backward()
print(weight1.grad)
optimizer.step()
optimizer.zero_grad()

To 2):
I think you will end up recreating the optimizer for each weight change if you use this approach.
At least I’m not aware of an easy way to update the optimizer in place.
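One quick way to see why the recreation is necessary (a minimal check based on the example above): after weight1 is re-assigned, the old optimizer still holds a reference to the old tensor object.

import torch
import torch.optim as optim

weight1 = torch.randn(6, 1, 3, 3, requires_grad=True)
optimizer = optim.SGD([weight1], lr=1e-3)
weight1 = torch.cat((weight1.data, torch.zeros(1, 1, 3, 3)), dim=0).requires_grad_(True)
print(optimizer.param_groups[0]['params'][0] is weight1)  # False -> the old optimizer would keep updating the stale tensor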

Thanks for your feedback.

To 2): I guess this will just have to do then. If there aren’t any apparent pitfalls where this can go wrong, the overhead shouldn’t be too bad.

To 1): I think I should have given a direct example to make it clear. Your solution with the functional API and torch.cat() looks nice, but I don’t think it addresses the issue here, as your example only deals with one layer (or maybe I’m just totally missing it? ;)). The challenge really comes in when dealing with layer/weight 2, which is also why I need a copy. My bad for not giving a concise, easy-to-reproduce example. So here it is:

For simplicity let’s assume these are the weights of two consecutive linear layers where we start with fan_in and fan_out = 1 in each layer (because the spatial dimensions don’t really contribute to the argument). I’m going to leave out the SGD parts because that seems to be solved.

weight1 = torch.randn(2,2, requires_grad=True) 
>>> weight1
tensor([[-1.1513, -1.6009],
        [ 0.5701,  0.4365]])
weight2 = torch.randn(2,2, requires_grad=True) 
>>> weight2
tensor([[ 0.7709,  1.9652],
        [ 1.7878, -0.3036]])

According to some criterion, we now add a unit/filter to weight1. This is easy, and I believe your cat solution could also be used here.

weight1.data.resize_(3,2)
tensor([[-1.1513e+00, -1.6009e+00],
        [ 5.7011e-01,  4.3649e-01],
        [ 8.8747e-29,  4.5771e-41]])
weight1[-1:,:].normal_()
>>> weight1
tensor([[-1.1513, -1.6009],
        [ 0.5701,  0.4365],
        [-0.5734,  0.3861]])

So, as we can see, we can use resize_() in dimension 0 without modifying the previous values (it needs to be done on .data, as tensors with requires_grad can’t be resized). I had left out the initialization for simplicity before, but of course you are right that we need to init the corresponding new slice, as otherwise there will be junk values in the weights.

But that’s only because we are resizing dimension 0! Now, if you think of these as two consecutive layers, this change in weight1 implies a change in weight2 in the other non-spatial dimension, because fan_out of one layer is fan_in of the next. And yes, I know it might look confusing, but in PyTorch fan_out is actually dim=0 and fan_in is dim=1. From a flattened point of view it makes sense though, as fan_in gets flattened together with the spatial dimensions dim=2,3 in the computation of convolutions.
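This convention is easy to check directly (just a quick illustration):

import torch.nn as nn

conv = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3)
print(conv.weight.shape)  # torch.Size([8, 4, 3, 3]) -> (out_channels/fan_out, in_channels/fan_in, kH, kW)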

If we naively run the same code, given that we are modifying dimension 1 instead of 0, we are going to mess up the previous information:

>>> weight2
tensor([[ 0.7709,  1.9652],
        [ 1.7878, -0.3036]])
weight2.data.resize_(2,3)
>>> tensor([[ 7.7089e-01,  1.9652e+00,  1.7878e+00],
        [-3.0362e-01,  8.2830e-29,  4.5771e-41]])

It’s already quite clear that the information is no longer arranged the same way (probably because of the internally flattened representation); the weights are not preserved.
So the only work-around I see is to do something like this:

tmp = weight2.data.clone()
weight2.data.resize_(2,3)
>>> tensor([[ 7.7089e-01,  1.9652e+00,  1.7878e+00],
        [-3.0362e-01,  1.6501e+19,  4.0092e-08]])
weight2.data[:,0:-1] = tmp
>>> weight2
tensor([[ 7.7089e-01,  1.9652e+00,  1.7878e+00],
        [ 1.7878e+00, -3.0362e-01,  4.0092e-08]])
weight2.data[:,-1:].normal_()
>>> weight2
tensor([[ 0.7709,  1.9652,  0.8498],
        [ 1.7878, -0.3036,  0.2938]])

This is what we actually want. So here you can see why I need a temporary copy and why I transfer it to the CPU: if you imagine a VGG/ResNet/SegNet or some other memory-intensive network, I wouldn’t really want to clone the entire weights on the GPU just to resize them.
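To be explicit, the CPU-transfer variant of the work-around looks roughly like this (a sketch, assuming weight2 actually lives on the GPU):

tmp = weight2.data.cpu()           # stage the old values in host memory
weight2.data.resize_(2, 3)         # grow the GPU tensor in-place
weight2.data[:, :-1].copy_(tmp)    # copy the preserved slice back to the GPU
weight2.data[:, -1:].normal_()     # initialize the new slice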

I hope this makes it clearer. I would appreciate any further feedback!


I’ve been thinking about your suggested solution again, and it does actually solve my issue :). I just didn’t think the torch.cat() operation through properly and had been too stuck on resize_() for some reason. It works with both the functional and the non-functional API.

So, for anyone reading this post, here is an example in the style of my previous code that works without the explicit copy operation:

weight1 = torch.randn(2,2,requires_grad=True) 
weight2 = torch.randn(2,2,requires_grad=True)
>>> weight1
tensor([[ 0.5209, -0.7963],
        [ 1.9585,  0.3662]])
>>> weight2
tensor([[-0.9158,  0.3723],
        [-1.1794,  0.1901]])

weight1.data = torch.cat((weight1.data, torch.randn(1,2)), dim=0)
>>> weight1
tensor([[ 0.5209, -0.7963],
        [ 1.9585,  0.3662],
        [ 0.6279,  2.3723]])
weight2.data = torch.cat((weight2.data, torch.randn(2,1)), dim=1)
>>> weight2
tensor([[-0.9158,  0.3723,  1.0036],
        [-1.1794,  0.1901, -0.6061]])

Thanks again @ptrblck =)


I’m glad it worked after all!
My code was quite messy, because I tried out several approaches and apparently the functional example stayed quite simple. :wink:

Another idea I had after reading your first response is to separate the convolutions.
It’s probably not needed anymore, since the code works, but it might still be interesting.

Since each filter does not depend on the other filters, we could apply separate convolutions and sum the results together. It’s a bit tricky to explain, so I created a small example.
The filters named convX_2 are the additional ones. We would have to slice the output of the previous layer, which doesn’t seem to be that good; I think the performance might suffer a bit.

import torch
import torch.nn as nn

# Normal model
conv1_1 = nn.Conv2d(3, 6, 3, 1, 1, bias=False)
conv2_1 = nn.Conv2d(6, 12, 3, 1, 1, bias=False)
conv2_1.weight.data.fill_(0.5)

x = torch.randn(1, 3, 4, 4)
output = conv1_1(x)
output = conv2_1(output)

# Now let's add an additional conv
conv1_2 = nn.Conv2d(3, 1, 3, 1, 1, bias=False)
conv2_2 = nn.Conv2d(1, 12, 3, 1, 1, bias=False)
conv2_2.weight.data.fill_(1.0)

# First layer
output1_1 = conv1_1(x)
output1_2 = conv1_2(x)

output1 = torch.cat((output1_1, output1_2), dim=1)

# Second layer
output2_1 = conv2_1(output1[:, :6])
output2_2 = conv2_2(output1[:, 6:7])
output2 = output2_1 + output2_2

# Compare with non-separated convolution
conv2_comp = nn.Conv2d(7, 12, 3, 1, 1, bias=False)
conv2_comp.weight.data[:, :6].fill_(0.5)
conv2_comp.weight.data[:, 6:7].fill_(1.0)

output2_comp = conv2_comp(output1)

print('Sum of abs error: {}'.format((output2 - output2_comp).abs().sum()))

Your current code might be cleaner, and this approach is not completely thought out, but maybe it’s useful in some sense. :slight_smile:

After some testing I found out that while the solution now looks cleaner with cat() instead of resize_(), as there are fewer lines, it still has the same memory problem.

Given that cat() is not an in-place operation, if I do something like:

a = torch.cat((a,b), dim=0) 

it seems to internally create a full copy. To give an extreme example to illustrate the point: if I have an 11 GB GPU and a is 10 GB in size but b is only 10 MB, it runs out of memory in the cat operation.
This seems equivalent to resizing in-place with resize_() and then filling the resized tensor from a clone(). I couldn’t find the source implementation of cat(), but it looks like this is what it does internally after all.
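A quick way to observe this peak allocation on the GPU (a small probe; sizes and numbers are approximate):

import torch

a = torch.randn(1024, 1024, 256, device='cuda')      # ~1 GB of float32
b = torch.randn(1, 1024, 256, device='cuda')         # ~1 MB
a = torch.cat((a, b), dim=0)                          # the old `a` is still alive while the result is written
print(torch.cuda.max_memory_allocated() / 1024**3)   # roughly 2 GB peak: old tensor + concatenated result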

So unfortunately, my first question remains open. I will have to think further about your latest post, but it seems really difficult to implement if one repeatedly adds units to a layer many times.

Yeah, you are right. It’s quite complicated, since the shapes will change a lot during the operations.
I’ve created a small example, which unfortunately doesn’t cover all your use cases.
You can expand additional layers, but only the last one multiple times.
Maybe there is a cleaner solution, but currently it looks like we would always have to create “helper” convolutions to match the dimensions.
Here is the code. It looks a bit complicated, so sorry for that:

from collections import OrderedDict

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 6, 3, 1, 1, bias=False)),
            ('conv2', nn.Conv2d(6, 12, 3, 1, 1, bias=False)),
            ('conv3', nn.Conv2d(12, 24, 3, 1, 1, bias=False)),
        ]))
        self.add_layers = nn.Sequential(OrderedDict([
            ('conv1', nn.ModuleList([])),
            ('conv2', nn.ModuleList([])),
            ('conv3', nn.ModuleList([])),
        ]))
        self.add_layers_help = nn.Sequential(OrderedDict([
            ('conv1', nn.ModuleList([])),
            ('conv2', nn.ModuleList([])),
            ('conv3', nn.ModuleList([])),
        ]))
    

    def expand_layer(self, layer_number, num_filters):
        # Create new conv layer for current layer
        in_channels0 = self.layers[layer_number].in_channels
        if layer_number > 0:
            in_channels0 += sum([add_layer.in_channels for add_layer in self.add_layers_help[layer_number]])
        out_channels0 = num_filters
        kernel_size0 = self.layers[layer_number].kernel_size
        layer0 = nn.Conv2d(in_channels0,
                           out_channels0,
                           kernel_size0,
                           stride=1,
                           padding=1,
                           bias=False)
        new_idx = len(self.add_layers[layer_number])
        self.add_layers[layer_number].add_module(
            str(new_idx), layer0)
        
        # Check if next layer needs modification
        if layer_number < (len(self.layers) - 1):
            in_channels1 = num_filters
            out_channels1 = self.layers[layer_number + 1].out_channels
            kernel_size1 = self.layers[layer_number + 1].kernel_size
            layer1 = nn.Conv2d(in_channels1,
                               out_channels1,
                               kernel_size1,
                               stride=1,
                               padding=1,
                               bias=False)
            new_idx = len(self.add_layers_help[layer_number + 1])
            self.add_layers_help[layer_number + 1].add_module(
                str(new_idx), layer1)

    def forward(self, x):
        # Iterate all layers
        for idx in range(len(self.layers)):
            layer = self.layers[idx]
            add_layer = self.add_layers[idx]
            add_layer_help = self.add_layers_help[idx]

            # Helper flag
            needs_layer = True

            # Check if "helper" layers are needed .
            # This is the case, if the previous output is bigger now.
            if len(add_layer_help) != 0:
                # Split the tensor so that the original layer and the helper
                # layers can perform the convolution together.
                split_idx = [layer.in_channels] + [l.in_channels for l in add_layer_help]
                x_split = torch.split(x, split_idx, dim=1)

                # Apply the original layer on the first split, since all
                # additional slices are appended at the end.
                x_out = layer(x_split[0])
                needs_layer = False
                # Iterate all helper layers and sum the outputs
                for x_, l in zip(x_split[1:], add_layer_help):
                    x_out = x_out + l(x_)

            # Check if any additional layers are needed
            if len(add_layer) != 0:
                x1 = torch.cat([l(x) for l in add_layer], dim=1)
                if needs_layer:
                    x_out = layer(x)
                    needs_layer = False
                x_out = torch.cat((x_out, x1), dim=1)

            # Check if no op was performed yet
            if needs_layer:
                x_out = layer(x)

            x = x_out

        return x


model = MyModel()
x = torch.randn(1, 3, 4, 4)    
output = model(x)
print(output.shape)

model.expand_layer(layer_number=0, num_filters=2)
output = model(x)
print(output.shape)

model.expand_layer(layer_number=1, num_filters=2)
output = model(x)
print(output.shape)

model.expand_layer(layer_number=2, num_filters=2)
output = model(x)
print(output.shape)

# Add another layer
model.expand_layer(layer_number=2, num_filters=2)
output = model(x)
print(output.shape)