Substitute nn.Modules in network

I am trying to prune a network. I took the nn.Modules (nn.Conv2d etc.) and simply replaced conv.weight and the other parameters with reduced tensors.

E.g., a Conv2d layer is pruned as follows:

    for param in [prev.weight, prev.bias]:
        # the assignment target was lost from the post; `param.data` is assumed here
        param.data = self.remove(param, lc)

My way of removing a channel:

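    import torch

    def remove(self, param, lc, dim=0):
        """Return a copy of `param` with index `lc` dropped along `dim`."""
        if dim == 0:  # BN params & biases (any vector) OR prev conv (remove filter)
            tmp1 = param[:lc]
            tmp2 = param[lc+1:]

        elif dim == 1:  # next conv (remove input channel)
            tmp1 = param[:, :lc, ...]  # out, in, H, W
            tmp2 = param[:, lc+1:, ...]  # out, in, H, W

        else:
            raise ValueError(f"unsupported dim: {dim}")

        # glue the two halves back together, skipping index lc
        return torch.cat([tmp1, tmp2], dim=dim)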
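For comparison, here is a minimal sketch of the "build brand-new modules" route: drop output filter `lc` from one conv and the matching input channel from the next, copying the surviving weights into fresh nn.Conv2d instances so their shapes and attributes agree. `prune_pair` and `remove_index` are hypothetical helpers (the latter does the same job as `remove` above); it assumes `groups=1` and the default `padding_mode`.

```python
import torch
import torch.nn as nn


def remove_index(t: torch.Tensor, lc: int, dim: int = 0) -> torch.Tensor:
    """Return a copy of `t` with index `lc` dropped along `dim`."""
    keep = [i for i in range(t.size(dim)) if i != lc]
    return t.index_select(dim, torch.tensor(keep, device=t.device))


@torch.no_grad()
def prune_pair(prev: nn.Conv2d, nxt: nn.Conv2d, lc: int):
    """Hypothetical helper: drop filter `lc` of `prev` and input channel
    `lc` of `nxt`, returning two new, consistently shaped Conv2d modules.
    Assumes groups=1 and the default padding_mode."""
    new_prev = nn.Conv2d(prev.in_channels, prev.out_channels - 1,
                         prev.kernel_size, stride=prev.stride,
                         padding=prev.padding, dilation=prev.dilation,
                         bias=prev.bias is not None).to(prev.weight.device)
    new_prev.weight.copy_(remove_index(prev.weight, lc, dim=0))
    if prev.bias is not None:
        new_prev.bias.copy_(remove_index(prev.bias, lc, dim=0))

    new_nxt = nn.Conv2d(nxt.in_channels - 1, nxt.out_channels,
                        nxt.kernel_size, stride=nxt.stride,
                        padding=nxt.padding, dilation=nxt.dilation,
                        bias=nxt.bias is not None).to(nxt.weight.device)
    new_nxt.weight.copy_(remove_index(nxt.weight, lc, dim=1))
    if nxt.bias is not None:
        new_nxt.bias.copy_(nxt.bias)  # output channels are unchanged
    return new_prev, new_nxt


prev = nn.Conv2d(3, 128, 3, padding=1)
nxt = nn.Conv2d(128, 64, 3, padding=1)
p, n = prune_pair(prev, nxt, lc=5)
n(p(torch.randn(1, 3, 8, 8))).sum().backward()  # backward now succeeds
print(n)  # Conv2d(127, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
```

Any BatchNorm between the two convs would need its vectors pruned the same way (dim 0).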

The forward pass works and the output size matches the pruned size, but the backward pass throws this error:

RuntimeError: Function CudnnConvolutionBackward returned an invalid gradient at index 1 - got [64, 127, 3, 3] but expected shape compatible with [64, 128, 3, 3]

By printing the network, I can see that the nn.Conv2d still retains the old shapes:

    Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
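The printout is stale because nn.Conv2d builds its repr from plain integer attributes (`in_channels`, `out_channels`, ...) set once in `__init__`, not from the current weight shape. A small sketch illustrating this, assuming the weight is swapped for a narrower nn.Parameter:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(128, 64, kernel_size=3, padding=1)

# Replace the weight with one that has a single input channel fewer.
conv.weight = nn.Parameter(conv.weight[:, :127].detach().clone())

print(conv.weight.shape)  # torch.Size([64, 127, 3, 3])
print(conv.in_channels)   # 128: an ordinary int attribute, never updated
print(conv)               # Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
```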

How should I change that particular nn.Conv2d? Specifically, after creating a brand-new nn.Conv2d and setting its parameters to the pruned weights of the old layer, how can I insert the new Conv2d at the correct place in my network? Or is there an easier way to prune?
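One way to put a new module at the right place, without hand-writing attribute paths, is to take its dotted name from `model.named_modules()` and `setattr` it on the parent. `replace_module` is a hypothetical helper; it relies only on standard nn.Module behaviour (`named_modules()` lists the root under `""`, and assigning a module attribute re-registers it under the same key, including the `"0"`, `"1"`, ... keys of nn.Sequential).

```python
import torch.nn as nn


def replace_module(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Hypothetical helper: swap the submodule registered under the dotted
    `name` (as reported by model.named_modules()) for `new_module`."""
    parent_name, _, child_name = name.rpartition(".")
    parent = dict(model.named_modules())[parent_name]  # "" maps to model itself
    setattr(parent, child_name, new_module)


model = nn.Sequential(
    nn.Conv2d(3, 128, 3, padding=1),
    nn.ReLU(),
    nn.Sequential(nn.Conv2d(128, 64, 3, padding=1)),
)
replace_module(model, "2.0", nn.Conv2d(127, 64, 3, padding=1))
print(model[2][0])  # Conv2d(127, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
```

This only works if `forward` looks the layer up by that attribute name; a reference stashed elsewhere (e.g. in a plain Python list) would keep pointing at the old module.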

Here’s my repo and files. I mostly work in and



  • You should reset the .grad field of your weight and bias to None if you change their size, to avoid any size mismatch when gradients are accumulated.
  • You should rewrap your new weight in an nn.Parameter() before setting it, so that it is properly caught by things like model.parameters(). Also note that pruning invalidates the previous model.parameters() (for example, an optimizer built from it still holds the old tensors), so you need to call it again after pruning.
  • You can monkey-patch the convolution parameters directly on the instance you have (see the names here for what to change), or create a brand-new module and replace the old one with something like model.conv1 = my_new_module.
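The three points above can be combined into an in-place variant. This is a minimal sketch, assuming `groups=1`; `drop_channel_` is a hypothetical helper. Each new nn.Parameter starts with `.grad = None`, so no gradient of the old shape is left behind, and the channel counts are patched so the module stays self-consistent.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def drop_channel_(conv: nn.Conv2d, lc: int, dim: int) -> None:
    """Hypothetical helper: remove output filter `lc` (dim=0) or input
    channel `lc` (dim=1) from `conv` in place. Assumes groups=1."""
    keep = [i for i in range(conv.weight.size(dim)) if i != lc]
    idx = torch.tensor(keep, device=conv.weight.device)
    # Rewrap in nn.Parameter: registered with the module, grad starts as None.
    conv.weight = nn.Parameter(conv.weight.index_select(dim, idx))
    if dim == 0:
        if conv.bias is not None:
            conv.bias = nn.Parameter(conv.bias.index_select(0, idx))
        conv.out_channels -= 1
    else:
        conv.in_channels -= 1


model = nn.Sequential(nn.Conv2d(3, 128, 3, padding=1), nn.ReLU(),
                      nn.Conv2d(128, 64, 3, padding=1))
drop_channel_(model[0], lc=5, dim=0)  # remove filter 5 of the first conv
drop_channel_(model[2], lc=5, dim=1)  # and the matching input channel
# Any optimizer created before pruning still references the old tensors.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(1, 3, 8, 8)).sum().backward()
opt.step()
print(model[2])  # Conv2d(127, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
```

If you instead keep the same Parameter object and only swap its `.data`, set `param.grad = None` yourself, as the first point says.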