Discussion of practical speedup for pruning

So I have been doing some experimenting, and I’ll post the results here for discussion even though they are far from finished. It seems to me that using the masks to skip unnecessary operations is probably the way to go. That way, any kind of structured “skipping” of operations is available without resorting to pointers or other sparse bookkeeping.

I started experimenting with Conv2d from torch.nn.modules.conv and decided to adapt its conv2d_forward() function. For testing, I am pruning ResNet18 on CIFAR10 (structured L1 pruning of the convolutions). As a baseline, running my model on the test set with the unmodified conv2d_forward() takes about 3.00 seconds.
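For anyone who wants to reproduce the setup, here is a minimal sketch of structured L1 pruning with torch.nn.utils.prune. It assumes that this is roughly how the `weight_mask` attribute checked below gets created (the post does not say which tool produces it); the single stand-in layer and the 50% pruning amount are illustrative choices, not the values used for the timings.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Assumption: one 3x3 convolution stands in for a ResNet18 layer.
conv = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)

# Structured L1 pruning: zero the 50% of output channels (dim=0)
# whose filters have the smallest L1 norm (n=1).
prune.ln_structured(conv, name="weight", amount=0.5, n=1, dim=0)

# conv.weight_orig is now the trainable parameter and conv.weight_mask a buffer;
# conv.weight = weight_orig * weight_mask is recomputed by a forward pre-hook.
mask = conv.weight_mask

# True for every output channel that still has at least one nonzero weight.
kept_out = (mask.flatten(1) != 0).any(dim=1)
print(f"{int(kept_out.sum())} of {conv.out_channels} output channels kept")
```

Because the mask has the same shape as the weight, "is output channel co pruned?" is just a reduction over `mask[co]`, which is what the per-channel experiments below rely on.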

First, to get an idea of the possible speedup, I made every masked layer skip the convolution and return zeros:

class Conv2d(_ConvNd):
    ...
    def conv2d_forward(self, input, weight):
        if hasattr(self, 'weight_mask'):
            bsize = input.size()[0]
            xdim = input.size()[2] // self.stride[0]
            ydim = input.size()[3] // self.stride[1]
            return torch.cuda.FloatTensor(bsize, self.out_channels, xdim, ydim).fill_(0)
        ...

This runs in about 1.30 seconds, which serves as a rough lower bound for now (of course, this is a hacky way to create a zero-filled tensor of the correct size). Next, I tried the following code, which skips the operation only if the entire layer is pruned:
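A sketch of a less hacky zero output, assuming numeric padding (not the string modes 'same'/'valid'). It uses the standard Conv2d output-size formula, out = floor((in + 2·padding − dilation·(kernel − 1) − 1) / stride) + 1, which also accounts for kernel size and dilation, and `Tensor.new_zeros`, which inherits dtype and device from the input instead of hard-coding `torch.cuda.FloatTensor`. The helper names are hypothetical.

```python
import torch
import torch.nn as nn


def conv2d_output_shape(module: nn.Conv2d, input: torch.Tensor):
    """Hypothetical helper: shape of module(input) via the Conv2d size formula.
    Assumes numeric (tuple) padding, not 'same'/'valid'."""
    n, _, h, w = input.shape
    dims = []
    for i, size in enumerate((h, w)):
        k = module.kernel_size[i]
        s = module.stride[i]
        p = module.padding[i]
        d = module.dilation[i]
        dims.append((size + 2 * p - d * (k - 1) - 1) // s + 1)
    return (n, module.out_channels, dims[0], dims[1])


def zero_output(module: nn.Conv2d, input: torch.Tensor) -> torch.Tensor:
    # new_zeros takes dtype and device from `input`, so this works on CPU and
    # GPU alike. Note: if the layer has a bias, the exact output of a fully
    # pruned conv is the bias, not zero.
    return input.new_zeros(conv2d_output_shape(module, input))


conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
x = torch.randn(4, 3, 32, 32)
print(zero_output(conv, x).shape, conv(x).shape)
```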

class Conv2d(_ConvNd):
    ...
    def conv2d_forward(self, input, weight):
        if hasattr(self, 'weight_mask'):
            if self.weight_mask.sum() == 0:
                bsize = input.size()[0]
                xdim = input.size()[2] // self.stride[0]
                ydim = input.size()[3] // self.stride[1]
                return torch.cuda.FloatTensor(bsize, self.out_channels, xdim, ydim).fill_(0)
        ...

Of course, this rarely happens in practice. If I keep only 0.1% of the convolutions in ResNet, it runs in 2.51 seconds, which is a speedup, but not much considering that we are reducing the number of parameters by a factor of 1000. Next, I tried to do this for each output channel of the layer:
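One detail that may cost time in the whole-layer check: `if self.weight_mask.sum() == 0` on a CUDA tensor forces a device-to-host synchronization on every forward pass. A sketch that inspects the mask once after pruning and stores a plain Python flag is below; `SkippableConv2d` and `refresh_skip_flag` are hypothetical names, and overriding `forward` (rather than `conv2d_forward`) is an assumption made to stay independent of the PyTorch version.

```python
import torch
import torch.nn as nn


class SkippableConv2d(nn.Conv2d):
    """Hypothetical Conv2d subclass that skips the convolution entirely when
    every weight is masked out. The mask is inspected once, in
    refresh_skip_flag(), instead of on every forward pass."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.all_pruned = False  # plain Python bool, cheap to test in forward()

    def refresh_skip_flag(self):
        # Call this once after (re-)pruning, e.g. with torch.nn.utils.prune.
        mask = getattr(self, "weight_mask", None)
        self.all_pruned = mask is not None and not bool((mask != 0).any())

    def forward(self, input):
        if not self.all_pruned:
            return super().forward(input)
        # Standard Conv2d output-size formula (assumes numeric padding).
        n, _, h, w = input.shape
        oh, ow = [
            (size + 2 * p - d * (k - 1) - 1) // s + 1
            for size, k, s, p, d in zip((h, w), self.kernel_size, self.stride,
                                        self.padding, self.dilation)
        ]
        out = input.new_zeros(n, self.out_channels, oh, ow)
        if self.bias is not None:
            # With all weights zero, the exact result is the bias alone.
            out = out + self.bias.view(1, -1, 1, 1)
        return out
```

This only removes the synchronization overhead; as the timings show, fully pruned layers are rare, so the larger gains have to come from skipping individual channels.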

from .. import functional as F

class Conv2d(_ConvNd):
    ...
    def conv2d_forward(self, input, weight):
        if hasattr(self, 'weight_mask'):
            bsize = input.size()[0]
            xdim = input.size()[2] // self.stride[0]
            ydim = input.size()[3] // self.stride[1]
            return_tensor = torch.cuda.FloatTensor(bsize, self.out_channels, xdim, ydim).fill_(0)
            for co in range(self.out_channels):
                if self.weight_mask[co].sum() != 0:
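                    # slice the bias to the single output channel (self.bias is None for ResNet convs)
                    bias = None if self.bias is None else self.bias[co:co+1]
                    return_tensor[:,co] = F.conv2d(input, weight[co][None,:,:,:], bias, self.stride, self.padding, self.dilation, self.groups)[:,0,:,:]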
            return return_tensor

Understandably, this is very slow (about 100 times slower than the baseline). My limited PyTorch experience is what led to this hacky code. If anyone has input on the following, it would be greatly appreciated:

  1. Is this approach viable, or does anyone have other ideas?
  2. Are there existing PyTorch functions for obtaining a tensor of zeros on the correct device and in the correct output shape of the convolution (@ptrblck)? I don't think the way I have done it now will hold up in all cases.
  3. How would one efficiently select only the convolutions that need to run on the GPU and assign their results to the return tensor? Perhaps it's easiest to start by grouping them channel-wise, as I did in the last example.
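Regarding questions 2 and 3, one possible direction (a sketch, not benchmarked on ResNet18) is to replace the per-channel Python loop with a single smaller convolution: gather the surviving filters, and the input channels they still read, into a compact weight, run one `F.conv2d`, and scatter the result into a full-size zero tensor with `index_copy`. It assumes groups=1, a weight that is already masked (`weight_orig * weight_mask`), and at least one surviving output and input channel; `compact_conv2d` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def compact_conv2d(input, weight, mask, bias=None, stride=1, padding=0, dilation=1):
    """Hypothetical helper: one dense conv over the surviving channels only.
    Assumes groups=1, `weight` already masked, and at least one output and
    one input channel surviving."""
    # Output channels (filters) with at least one nonzero mask entry.
    keep_out = (mask.flatten(1) != 0).any(dim=1).nonzero().squeeze(1)
    # Input channels still read by at least one filter.
    keep_in = (mask.transpose(0, 1).flatten(1) != 0).any(dim=1).nonzero().squeeze(1)

    # Shrink the problem: fewer filters, and fewer input channels per filter.
    w_small = weight.index_select(0, keep_out).index_select(1, keep_in)
    x_small = input.index_select(1, keep_in)
    out_small = F.conv2d(x_small, w_small, None, stride, padding, dilation)

    # Scatter the surviving channels into an all-zero tensor of the full shape;
    # new_zeros inherits dtype and device from the input.
    n, _, oh, ow = out_small.shape
    out = input.new_zeros(n, weight.shape[0], oh, ow)
    out = out.index_copy(1, keep_out, out_small)
    if bias is not None:
        # Pruned filters still contribute their bias, as in the dense conv.
        out = out + bias.view(1, -1, 1, 1)
    return out
```

In a real module, `keep_out`, `keep_in`, and the sliced weight would be computed once after pruning rather than in every forward, since `.nonzero()` also synchronizes with the host. Whether this beats the dense convolution will depend on how many channels survive per layer.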

Any other type of input is also appreciated.

Richard
