Parameters are not updated in my custom model

1. I made a custom model, AlexNetQIL (AlexNet with a QIL layer); "QIL" means quantization interval learning.
2. I trained the model but the loss did not decrease at all, and I found out that the parameters were not being updated because of the QIL layer I added (see the check sketched below).
3. I attached my code for AlexNetQIL and Qil. Could someone please tell me what the problem is?
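For reference, a check like the following shows whether any parameter actually changes after one optimizer step (an illustrative sketch, not my actual training script; the dummy batch and optimizer settings are assumptions):

import torch
import torch.nn as nn

model = AlexNetQIL()                               # the model defined below
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
images = torch.randn(4, 3, 32, 32)                 # dummy CIFAR-10-sized batch
labels = torch.randint(0, 10, (4,))

# snapshot every parameter, run one training step, then compare
before = {n: p.detach().clone() for n, p in model.named_parameters()}
optimizer.zero_grad()
criterion(model(images), labels).backward()
optimizer.step()

for n, p in model.named_parameters():
    if p.grad is None or p.grad.abs().sum().item() == 0:
        print(f"{n}: no gradient")
    if n in before and torch.equal(before[n], p.detach()):
        print(f"{n}: value did not change")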

AlexNetQIL

import torch
import torch.nn as nn
from qil import *

class AlexNetQIL(nn.Module):

    #def __init__(self, num_classes=1000): for imagenet
    def __init__(self, num_classes=10): # for cifar-10
        super(AlexNetQIL, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.qil2 = Qil()
        self.conv2 = nn.Conv2d(64, 192, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(192)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.qil3 = Qil()
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(384)
        self.relu3 = nn.ReLU(inplace=True)

        self.qil4 = Qil()
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU(inplace=True)

        self.qil5 = Qil()
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxpool5 = nn.MaxPool2d(kernel_size=2)

        self.classifier = nn.Sequential(
            nn.Linear(256 * 2 * 2, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
    def forward(self,x,inference = False):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x,self.conv2.weight = self.qil2(x,self.conv2.weight,inference ) # if I remove this line, No problem 
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x,self.conv3.weight = self.qil3(x,self.conv3.weight,inference ) # if I remove this line, No problem 
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)

        x,self.conv4.weight = self.qil4(x,self.conv4.weight,inference ) # if I remove this line, No problem 
        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu4(x)

        x,self.conv5.weight = self.qil5(x,self.conv5.weight,inference ) # if I remove this line, No problem 
        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu5(x)
        x = self.maxpool5(x)
        x = x.view(x.size(0),256 * 2 * 2)
        x = self.classifier(x)
        return x

QIL

forward:

  • quantize the weights and the input activations in 2 steps
  • transformer(params) -> discretizer(params)

import torch
import torch.nn as nn
import numpy as np
import copy

#Qil (Quantize intervals learning)
class Qil(nn.Module):

    discretization_level = 32
    def __init__(self):
        super(Qil,self).__init__()
        self.cw = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.dw = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.cx = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.dx = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.gamma = nn.Parameter(torch.tensor(1.0))  # I have to train this transformer parameter

        self.a = Qil.discretization_level
    def forward(self,x,weights,Inference = False):
        if not Inference:
            weights = self.transfomer_weights(weights)
            weights = self.discretizer(weights)
        x = self.transfomer_activation(x)
        x = self.discretizer(x)
        return torch.nn.Parameter(x), torch.nn.Parameter(weights)

    def transfomer_weights(self,weights):
        device = weights.device
        aw,bw = (0.5 / self.dw) , (-0.5*self.cw / self.dw + 0.5)

        weights = torch.where( abs(weights) < self.cw - self.dw,
                                torch.tensor(0.).to(device),weights)
        weights = torch.where( abs(weights) > self.cw + self.dw,
                                weights.sign(), weights)
        weights = torch.where( (abs(weights) >= self.cw - self.dw) & (abs(weights) <= self.cw + self.dw),
                                (aw*abs(weights) + bw)**self.gamma * weights.sign() , weights)
        return weights

    def transfomer_activation(self,x):
        device = x.device
        ax,bx = (0.5 / self.dx) , (-0.5*self.cx / self.dx + 0.5)

        x = torch.where(x < self.cx - self.dx,
                        torch.tensor(0.).to(device),x)
        x = torch.where(x > self.cx + self.dx,
                        torch.tensor(1.0).to(device),x)
        x = torch.where( (abs(x) >= self.cx - self.dx) & (abs(x) <= self.cx + self.dx),
                            ax*abs(x) + bx, x)
        return x

    def discretizer(self,tensor):
        q_D = pow(2, Qil.discretization_level)
        tensor = torch.round(tensor * q_D) / q_D
        return tensor

Hi,

The problem is with your QIL() layer.
Why do you wrap the results in torch.nn.Parameter()? nn.Parameter is only meant to store parameters created during the initialization of an nn.Module.

Also, if you don't want to keep the original weights of the convolutions but instead replace them in the forward with the ones you compute, you should do del self.conv5.weight (and likewise for the other convs) in the init, so that these weights are not registered as learnable parameters and you can set them during the forward.
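To make that concrete, here is a minimal sketch of the idea for a single conv (my own illustration, not tested; it assumes Qil.forward returns plain tensors rather than nn.Parameter-wrapped ones, and weight5 is just an illustrative attribute name):

class AlexNetQILSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.qil5 = Qil()
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        # keep the original kernel as a learnable parameter of this module ...
        self.weight5 = self.conv5.weight
        # ... and unregister it from the conv so a plain tensor can be assigned in forward
        del self.conv5.weight

    def forward(self, x, inference=False):
        # the quantized weight is an ordinary tensor computed from self.weight5,
        # so autograd can track the computation back to the original parameter
        x, w = self.qil5(x, self.weight5, inference)
        self.conv5.weight = w
        return self.conv5(x)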

Thank you,
Yes, I thought nn.Parameter(x) and nn.Parameter(weights) were the problem, but if I don't wrap those tensors in nn.Parameter(), I get an error like:

TypeError: cannot assign ‘torch.cuda.FloatTensor’ as parameter ‘weight’ (torch.nn.Parameter or None expected)

This happens because torch.where() in the transformer function returns a torch.cuda.FloatTensor.
I don't know how to solve this problem.

To access the underlying tensor of an nn.Parameter, use the data attribute:

weight.data = self.transfomer_weights(weights.data)

This will replace the value of the weight while keeping the nn.Parameter instance registered in the module.
Is this what you are looking for?

You should not do that, as it will break the computational graph and you will get wrong gradients!
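A tiny illustration of why (an assumed example, not from the thread): anything done through .data is invisible to autograd, so it never shows up in the gradients.

import torch
import torch.nn as nn

p = nn.Parameter(torch.tensor([2.0]))
p.data = p.data * 3        # changes the value, but autograd never records this multiply
loss = (p * 1.0).sum()
loss.backward()
print(p.grad)              # tensor([1.]) -- as if the *3 had never happened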

@Suyoung_Park check the second part of my message. That explains how to avoid the issue with trying to change the nn.Parameter. In particular, you need to delete the unused weights in the init.

You're right; I didn't notice that there are additional learnable parameters used in the quantization which need gradients.

Then I think the cleanest way to do this is to write your own Conv2d subclass whose forward function passes the quantized weights directly to F.conv2d.

Something like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Qil_Conv2d(nn.Conv2d):

    discretization_level = 32
    def __init__(self, *args, **kwargs):
        super(Qil_Conv2d,self).__init__(*args, **kwargs)
        self.cw = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.dw = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.cx = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.dx = nn.Parameter(torch.rand(1)) # I have to train this interval parameter
        self.gamma = nn.Parameter(torch.tensor(1.0))  # I have to train this transformer parameter

        self.a = Qil_Conv2d.discretization_level
    def forward(self, x, Inference=False):
        w = self.weight                      # nn.Conv2d registers the kernel as 'weight'
        if not Inference:
            w = self.transfomer_weights(w)
            w = self.discretizer(w)
        x = self.transfomer_activation(x)
        x = self.discretizer(x)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def transfomer_weights(self,weights):
        device = weights.device
        aw,bw = (0.5 / self.dw) , (-0.5*self.cw / self.dw + 0.5)

        weights = torch.where( abs(weights) < self.cw - self.dw,
                                torch.tensor(0.).to(device),weights)
        weights = torch.where( abs(weights) > self.cw + self.dw,
                                weights.sign(), weights)
        weights = torch.where( (abs(weights) >= self.cw - self.dw) & (abs(weights) <= self.cw + self.dw),
                                (aw*abs(weights) + bw)**self.gamma * weights.sign() , weights)
        return weights

    def transfomer_activation(self,x):
        device = x.device
        ax,bx = (0.5 / self.dx) , (-0.5*self.cx / self.dx + 0.5)

        x = torch.where(x < self.cx - self.dx,
                        torch.tensor(0.).to(device),x)
        x = torch.where(x > self.cx + self.dx,
                        torch.tensor(1.0).to(device),x)
        x = torch.where( (abs(x) >= self.cx - self.dx) & (abs(x) <= self.cx + self.dx),
                            ax*abs(x) + bx, x)
        return x

    def discretizer(self,tensor):
        q_D = pow(2, Qil_Conv2d.discretization_level)
        tensor = torch.round(tensor * q_D) / q_D
        return tensor

This should work well, replacing the roles of both Qil and nn.Conv2d (not tested).
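Usage could then look roughly like this (an assumed sketch with illustrative shapes, relying on the class above; not tested):

conv2 = Qil_Conv2d(64, 192, kernel_size=3, padding=1)
x = torch.randn(8, 64, 16, 16)

out = conv2(x)                         # training path: weight and input are quantized
out_eval = conv2(x, Inference=True)    # inference path: weight is used as-is

# the interval parameters are registered next to the usual conv parameters
print([name for name, _ in conv2.named_parameters()])
# e.g. ['weight', 'bias', 'cw', 'dw', 'cx', 'dx', 'gamma']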

Now I understand: if I wrap the tensors with nn.Parameter() or use .data, I will get wrong gradients. But what do you mean by deleting the unused weights (only conv5.weight?) in the init?

Thank you,
It worked well!!
but I had to modify one part of transfomer_weights to get correct gradients:

# before
weights = torch.where( abs(weights) < self.cw - self.dw, torch.tensor(0.).to(device),weights)
# after
weights_out = torch.where( abs(weights) < self.cw - self.dw, torch.tensor(0.).to(device),weights)

in this way (writing the result to a new variable instead of overwriting weights).
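Applying that change to all three torch.where calls gives something like this (a sketch, not the exact code from my run):

    def transfomer_weights(self, weights):
        device = weights.device
        aw, bw = (0.5 / self.dw), (-0.5 * self.cw / self.dw + 0.5)

        # conditions are always evaluated on the original weights;
        # results are written into weights_out instead of overwriting weights
        weights_out = torch.where(abs(weights) < self.cw - self.dw,
                                  torch.tensor(0.).to(device), weights)
        weights_out = torch.where(abs(weights) > self.cw + self.dw,
                                  weights.sign(), weights_out)
        weights_out = torch.where((abs(weights) >= self.cw - self.dw) & (abs(weights) <= self.cw + self.dw),
                                  (aw * abs(weights) + bw) ** self.gamma * weights.sign(), weights_out)
        return weights_out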
Thank you very much!!!