Accessing modules - Custom ResNet18

I am using a ResNet-18 coded as follows:


class ResidualBlock(nn.Module):
    """Residual block within a ResNet CNN model.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The block input
    (optionally projected) is added to it, and a final ReLU is applied.

    Args:
        input_channels: Number of channels of the block input.
        num_channels: Number of output channels of every convolution.
        use_1x1_conv: If True, the shortcut is a 1x1 convolution (with the
            same stride as ``conv1``) followed by batch norm, so that its
            channel count and spatial size match the main path. If False,
            the shortcut is the identity. In that case ``input_channels``
            must equal ``num_channels`` and ``strides`` must be 1, otherwise
            the residual addition fails on a shape mismatch.
        strides: Stride of ``conv1`` and of the 1x1 shortcut convolution.
    """

    def __init__(self, input_channels, num_channels,
                 use_1x1_conv=False, strides=1):
        super().__init__()

        # Convs that feed a BatchNorm are created without a bias term
        # (the following BN layer supplies the learned shift).
        self.conv1 = nn.Conv2d(
            in_channels=input_channels, out_channels=num_channels,
            kernel_size=3, padding=1, stride=strides,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=num_channels)

        self.conv2 = nn.Conv2d(
            in_channels=num_channels, out_channels=num_channels,
            kernel_size=3, padding=1, stride=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(num_features=num_channels)

        if use_1x1_conv:
            # NOTE: this conv keeps the default bias=True, so it shows up as
            # ``conv3.bias`` in named_parameters(). It is left unchanged so
            # existing checkpoints keep loading.
            self.conv3 = nn.Conv2d(
                in_channels=input_channels, out_channels=num_channels,
                kernel_size=1, stride=strides,
            )
            self.bn3 = nn.BatchNorm2d(num_features=num_channels)
        else:
            self.conv3 = None

        # Not used by forward() (which calls F.relu). It is kept so the
        # registered module tree stays the same as before.
        self.relu = nn.ReLU(inplace=True)

        self.initialize_weights()

    def forward(self, X):
        """Apply the residual block to ``X`` and return the activated sum."""
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))

        # Compare against None explicitly instead of relying on the
        # truthiness of an nn.Module.
        if self.conv3 is not None:
            X = self.bn3(self.conv3(X))

        Y += X
        return F.relu(Y)

    def shape_computation(self, X):
        """Debug helper: print the output shape of each convolution for ``X``.

        Only the convolutions are applied (no BN/ReLU). They are the only
        layers here that change the shape.
        """
        Y = self.conv1(X)
        print(f"self.conv1(X).shape: {Y.shape}")
        Y = self.conv2(Y)
        # conv2 is applied to conv1's output, not to X directly.
        print(f"self.conv2(self.conv1(X)).shape: {Y.shape}")

        if self.conv3 is not None:
            h = self.conv3(X)
            print(f"self.conv3(X).shape: {h.shape}")

    def initialize_weights(self):
        """Initialize weights of every Conv2d, BatchNorm2d and Linear submodule.

        - Conv2d: Kaiming-uniform weights. Biases (only ``conv3`` has one)
          are intentionally left at PyTorch's default initialization.
        - BatchNorm2d: weight = 1, bias = 0.
        - Linear: Kaiming-normal weights, bias = 0.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                # Standard initialization for batch normalization.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

# Stem: 3x3 conv (3 -> 64 channels, stride 1, padding 1) with its default
# bias, followed by batch norm and ReLU.
b0 = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
)

def create_resnet_block(input_filters, output_filters, num_residuals, first_block = False):
    """Build the list of ResidualBlock modules that make up one ResNet stage.

    Args:
        input_filters: Channels entering the stage. Only used by the first
            block when ``first_block`` is True.
        output_filters: Channels produced by every block of the stage.
        num_residuals: Number of ResidualBlock modules to create.
        first_block: If True, block 0 maps ``input_filters`` to
            ``output_filters`` with a 1x1-conv shortcut and stride 2. All
            other blocks map ``output_filters`` to ``output_filters`` with
            stride 1 and an identity shortcut.

    Returns:
        A Python list of ResidualBlock modules, in order.
    """
    def build(index):
        # Only the very first block of a "first_block" stage changes shape.
        if index == 0 and first_block:
            return ResidualBlock(
                input_channels=input_filters, num_channels=output_filters,
                use_1x1_conv=True, strides=2,
            )
        return ResidualBlock(
            input_channels=output_filters, num_channels=output_filters,
            use_1x1_conv=False, strides=1,
        )

    return [build(index) for index in range(num_residuals)]

# Four residual stages with two blocks each. Every call passes
# first_block=True, so each stage starts with a stride-2 block that has a
# 1x1-conv shortcut, and the channel count doubles from b2 onwards.
b1 = nn.Sequential(*create_resnet_block(
    input_filters=64, output_filters=64,
    num_residuals=2, first_block=True,
))

b2 = nn.Sequential(*create_resnet_block(
    input_filters=64, output_filters=128,
    num_residuals=2, first_block=True,
))

b3 = nn.Sequential(*create_resnet_block(
    input_filters=128, output_filters=256,
    num_residuals=2, first_block=True,
))

b4 = nn.Sequential(*create_resnet_block(
    input_filters=256, output_filters=512,
    num_residuals=2, first_block=True,
))

# ResNet-18 CNN model: stem (b0), four residual stages (b1-b4), global
# average pooling, flatten, and a 512 -> 10 linear classifier. The Sequential
# indices 0-7 become the leading component of every parameter name
# (e.g. "0.0.weight", "1.0.conv1.weight", "7.weight").
model = nn.Sequential(
    b0,
    b1,
    b2,
    b3,
    b4,
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(512, 10),
)

The layer names are now as follows:

# List every learnable tensor by its dotted name and shape.
for layer_name, param in dict(trained_model.named_parameters()).items():
    print(f"layer name: {layer_name} has {param.shape}")

layer name: 0.0.weight has torch.Size([64, 3, 3, 3])
layer name: 0.0.bias has torch.Size([64])
layer name: 0.1.weight has torch.Size([64])
layer name: 0.1.bias has torch.Size([64])
layer name: 1.0.conv1.weight has torch.Size([64, 64, 3, 3])
layer name: 1.0.bn1.weight has torch.Size([64])
layer name: 1.0.bn1.bias has torch.Size([64])
layer name: 1.0.conv2.weight has torch.Size([64, 64, 3, 3])
layer name: 1.0.bn2.weight has torch.Size([64])
layer name: 1.0.bn2.bias has torch.Size([64])
layer name: 1.0.conv3.weight has torch.Size([64, 64, 1, 1])
layer name: 1.0.conv3.bias has torch.Size([64])
layer name: 1.0.bn3.weight has torch.Size([64])
layer name: 1.0.bn3.bias has torch.Size([64])
layer name: 1.1.conv1.weight has torch.Size([64, 64, 3, 3])
layer name: 1.1.bn1.weight has torch.Size([64])
layer name: 1.1.bn1.bias has torch.Size([64])
layer name: 1.1.conv2.weight has torch.Size([64, 64, 3, 3])
layer name: 1.1.bn2.weight has torch.Size([64])
layer name: 1.1.bn2.bias has torch.Size([64])
layer name: 2.0.conv1.weight has torch.Size([128, 64, 3, 3])
layer name: 2.0.bn1.weight has torch.Size([128])
layer name: 2.0.bn1.bias has torch.Size([128])
layer name: 2.0.conv2.weight has torch.Size([128, 128, 3, 3])
layer name: 2.0.bn2.weight has torch.Size([128])
layer name: 2.0.bn2.bias has torch.Size([128])
layer name: 2.0.conv3.weight has torch.Size([128, 64, 1, 1])
layer name: 2.0.conv3.bias has torch.Size([128])
layer name: 2.0.bn3.weight has torch.Size([128])
layer name: 2.0.bn3.bias has torch.Size([128])
layer name: 2.1.conv1.weight has torch.Size([128, 128, 3, 3])
layer name: 2.1.bn1.weight has torch.Size([128])
layer name: 2.1.bn1.bias has torch.Size([128])
layer name: 2.1.conv2.weight has torch.Size([128, 128, 3, 3])
layer name: 2.1.bn2.weight has torch.Size([128])
layer name: 2.1.bn2.bias has torch.Size([128])
layer name: 3.0.conv1.weight has torch.Size([256, 128, 3, 3])
layer name: 3.0.bn1.weight has torch.Size([256])
layer name: 3.0.bn1.bias has torch.Size([256])
layer name: 3.0.conv2.weight has torch.Size([256, 256, 3, 3])
layer name: 3.0.bn2.weight has torch.Size([256])
layer name: 3.0.bn2.bias has torch.Size([256])
layer name: 3.0.conv3.weight has torch.Size([256, 128, 1, 1])
layer name: 3.0.conv3.bias has torch.Size([256])
layer name: 3.0.bn3.weight has torch.Size([256])
layer name: 3.0.bn3.bias has torch.Size([256])
layer name: 3.1.conv1.weight has torch.Size([256, 256, 3, 3])
layer name: 3.1.bn1.weight has torch.Size([256])
layer name: 3.1.bn1.bias has torch.Size([256])
layer name: 3.1.conv2.weight has torch.Size([256, 256, 3, 3])
layer name: 3.1.bn2.weight has torch.Size([256])
layer name: 3.1.bn2.bias has torch.Size([256])
layer name: 4.0.conv1.weight has torch.Size([512, 256, 3, 3])
layer name: 4.0.bn1.weight has torch.Size([512])
layer name: 4.0.bn1.bias has torch.Size([512])
layer name: 4.0.conv2.weight has torch.Size([512, 512, 3, 3])
layer name: 4.0.bn2.weight has torch.Size([512])
layer name: 4.0.bn2.bias has torch.Size([512])
layer name: 4.0.conv3.weight has torch.Size([512, 256, 1, 1])
layer name: 4.0.conv3.bias has torch.Size([512])
layer name: 4.0.bn3.weight has torch.Size([512])
layer name: 4.0.bn3.bias has torch.Size([512])
layer name: 4.1.conv1.weight has torch.Size([512, 512, 3, 3])
layer name: 4.1.bn1.weight has torch.Size([512])
layer name: 4.1.bn1.bias has torch.Size([512])
layer name: 4.1.conv2.weight has torch.Size([512, 512, 3, 3])
layer name: 4.1.bn2.weight has torch.Size([512])
layer name: 4.1.bn2.bias has torch.Size([512])
layer name: 7.weight has torch.Size([10, 512])
layer name: 7.bias has torch.Size([10])

To prune this model, I am following the PyTorch pruning tutorial. It says that to prune a module/layer, you use code like the following:

# Example from the PyTorch pruning tutorial. That tutorial's model exposes its
# layers as named attributes (conv1, conv2, fc1, ...). The nn.Sequential
# ResNet above has no such attributes: its parameter names start with
# numeric indices (e.g. "0.0.weight").
parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

But for the code above, the modules/layers no longer have this naming convention. For example, to prune the first conv layer of this model:

layer name: 0.0.weight has torch.Size([64, 3, 3, 3])

When I try the following code:

# Failing attempt: `model.0.0` is a SyntaxError because a Python attribute
# name cannot start with a digit.
prune.random_unstructured(model.0.0, name = 'weight', amount = 0.3)

It gives me the error:

prune.random_unstructured(trained_model.0.0, name = 'weight', amount = 0.3)
^
SyntaxError: invalid syntax

How do I handle this?

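You cannot access numerically named submodules with dot syntax, because `trained_model.0` is not valid Python.
You have to use
trained_model._modules['0'] instead of trained_model.0 (since the model is an nn.Sequential, plain indexing such as trained_model[0][0] also works).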

1 Like

Is there any way to write this as a for loop? For ResNet-152 or similar, adding a "prune" call for each conv layer individually doesn't scale well.

You can iterate over the parameters and filter.
Like

import torch.nn.utils.prune as prune
from torch import nn

# prune.* functions take a module plus the *name* of one of its parameters,
# not a parameter tensor. So walk the modules and filter by layer type.
# model.modules() recurses into the nested Sequential stages, so every
# Conv2d/Linear is reached without naming each one individually. Use
# model.named_modules() instead if you also need to filter by module name.
for module in model.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        prune.random_unstructured(module, name='weight', amount=0.3)
1 Like