Custom layer's weights do not update

Hi,

I've recently been doing some research on network pruning and built a custom layer for my VGG network.
My custom layer Mask contains a parameter with the same shape as the input feature map. However, after assembling the module together with my other conv layers in an nn.Sequential, the parameter self.mask does not get updated.

I suspect my custom module breaks some rules in PyTorch; could you please give me some advice on getting it to work?

My custom layer is defined as follows:

import torch
import torch.nn as nn
from torch.nn import Parameter


class Mask(nn.Module):
    def __init__(self, bits=1):
        super(Mask, self).__init__()
        self.mask = None
        self.bits = bits

        self.weight_shape = None
        self.channel = None
        self.height = None
        self.width = None

    def forward(self, input):
        # Lazily create the mask on the first forward pass, once the
        # (C, H, W) shape of the input feature map is known.
        if self.mask is None:
            self.mask = Parameter(torch.Tensor(input.shape[1:]).cuda())
            nn.init.constant_(self.mask, val=1.0)
            self.weight_shape = input.shape[1:]
            self.channel = input.shape[1]
            self.height = input.shape[2]
            self.width = input.shape[3]
        result = input * self.mask
        return result

The Mask layers are built using:

def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            # Insert a Mask layer in front of every conv layer.
            mask = utils.Mask(bits=1).cuda()
            conv2d = utils.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [mask, conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [mask, conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)
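
For reference, the feature extractor is then built from a VGG-style cfg roughly like this (the values below just follow the standard VGG-11 configuration and are only illustrative; 'M' entries insert max-pool layers):

cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
features = make_layers(cfg, batch_norm=True)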

I initially defined self.mask as None and initialized it in the forward() function because my intention is to mask the input of the layer, so I cannot fetch the input size while initializing the module. But I suspect something is wrong here. If so, is there another way to do this correctly? Thanks in advance!

If you’ve passed all parameters to the optimizer after initializing the model via:

optimizer = optim.SGD(model.parameters(), lr=1e-3)

the self.mask parameter won't be included and thus won't be updated, since it hadn't been created yet at that point.
Could you pass the desired input shape to the initialization of the model and create self.mask in its __init__ method?
Based on your code it seems that self.mask is initialized to a static shape in the first forward pass.
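
Something like this minimal sketch might work, assuming you can derive the (C, H, W) shape of each feature map from your cfg (mask_shape here is just a hypothetical argument name):

import torch
import torch.nn as nn


class Mask(nn.Module):
    def __init__(self, mask_shape, bits=1):
        super(Mask, self).__init__()
        self.bits = bits
        # Created in __init__, so model.parameters() already contains the mask
        # when the optimizer is constructed.
        self.mask = nn.Parameter(torch.ones(mask_shape))

    def forward(self, input):
        return input * self.mask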


Thank you for your reply! Your answer confirmed my suspicion that the weights are initialized improperly. But is there any workaround for my situation? Since the shape of self.mask depends on the input size from the dataset and the depth of the previous layer, I don't think I'll be able to define its shape during __init__.

Would it be possible to feed a single batch to the model in order to initialize the mask and then pass the parameters to the optimizer?


Yes, it worked, you saved my day! I added a single iteration without training before the definition of the optimizer, and the mask now updates properly.

def mask_init(train_loader, model):
    # Run a single forward pass (no training) so every Mask layer can
    # create its self.mask parameter with the correct shape.
    print("Initializing mask shape...")
    for i, (input, target) in enumerate(train_loader):
        if not args.cpu:
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
        output = model(input)
        break
    print("Mask shape initialized!")

That sounds great! :)


Also, there is a special case worth mentioning: when you wrap the model with torch.nn.DataParallel(model) or use other multiprocessing or multi-GPU methods, you want to run that single iteration of model.forward() before model.cuda() and the related distributed-training setup code. Otherwise the newly initialized tensor won't actually be assigned to the attribute you want, because the forward pass runs on a replica rather than on the original module.
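
A rough sketch of that ordering, assuming a single-node torch.nn.DataParallel setup and the mask_init / train_loader from above (build_model is just a placeholder for however you construct the VGG that contains the Mask layers):

import torch.nn as nn
import torch.optim as optim

model = build_model()              # placeholder constructor; layers already live on the GPU, as in make_layers above
mask_init(train_loader, model)     # warm-up forward on the plain module, so the masks are assigned to this model, not to a replica
model = nn.DataParallel(model)     # wrap for multi-GPU training only after the mask parameters exist
optimizer = optim.SGD(model.parameters(), lr=1e-3)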