Is it possible to prune a pretrained model using the PyTorch pruning API?


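import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")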

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

In this tutorial, the author discusses various types of pruning (structured, unstructured, local, global), but all of those techniques are applied to a vanilla LeNet model. Is it possible to prune a pretrained model using the built-in PyTorch pruning API (PyTorch 1.9.0 documentation)?

If yes, then how?

For example, with

models.resnet50(pretrained=True)

do I need to set pretrained=False? Or do I need to write the pretrained model from scratch to apply the pruning techniques? Or something else?

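You can iterate over the modules of the pretrained model and prune each module one by one. The pruning functions operate on a module's parameters, so they work the same way whether the weights were loaded from a checkpoint or randomly initialized; you do not need pretrained=False or a rewritten model. The same tutorial has this example script: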

new_model = LeNet()  # replace with your pretrained model, e.g. models.resnet50(pretrained=True)
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
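The global variant from the question can be applied to an arbitrary model in the same spirit. Below is a minimal sketch, not the tutorial's code: `collect_prunable` is a hypothetical helper, and the small `nn.Sequential` network is only a stand-in so the example runs without downloading weights. A model such as `torchvision.models.resnet50(pretrained=True)` could presumably be passed in its place.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def collect_prunable(model):
    """Hypothetical helper: (module, 'weight') pairs for all Conv2d/Linear layers."""
    return [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]


# Stand-in for a pretrained network (assumption, chosen so this runs offline).
model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 6 * 6, 10),
)

parameters_to_prune = collect_prunable(model)

# Remove the 20% of weights with the smallest magnitude across all listed layers.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Fraction of weights that are now zero, measured over all pruned layers.
total = sum(m.weight.nelement() for m, _ in parameters_to_prune)
zeros = sum(int(torch.sum(m.weight == 0)) for m, _ in parameters_to_prune)
global_sparsity = zeros / total
print(f"Global sparsity: {100.0 * global_sparsity:.2f}%")
```

With global pruning, the sparsity of an individual layer may differ from 20%; only the total across the listed layers is fixed.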

However, I am not sure you gain anything (such as lower memory use or faster inference) from pruning, especially the unstructured kind. These utilities only mask weights to zero; the weight tensors keep their original dense shape. I had high expectations at first, but was disappointed with what I got.
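One way to check this for yourself is to inspect a single layer before and after pruning. The sketch below uses an arbitrary `nn.Linear(100, 50)` layer (an assumption for illustration, not taken from the tutorial) and shows that the weight tensor keeps its shape even after `prune.remove` makes the pruning permanent.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(100, 50)
shape_before = tuple(layer.weight.shape)

# Mask out the 40% of weights with the smallest magnitude.
prune.l1_unstructured(layer, name="weight", amount=0.4)

# While pruned, the module stores the original weights plus a mask of equal size.
param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = sorted(name for name, _ in layer.named_buffers())
print(param_names)
print(buffer_names)

# Fold the mask into the weights and drop the reparametrization.
prune.remove(layer, "weight")

# The tensor is still dense and the same size; it just contains zeros.
shape_after = tuple(layer.weight.shape)
sparsity = float(torch.sum(layer.weight == 0)) / layer.weight.nelement()
print(shape_before, shape_after, f"{sparsity:.2%}")
```

Because the zeros are stored explicitly, any saving in memory or time would likely have to come from a further step, such as a sparse storage format or structured pruning that physically removes channels.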