Setting the number of bits for weights and activations

I would like to change the number of bits for weights and activations during testing. The network is pretrained and the .pth file is available, which as far as I understand stores 32-bit values. Let's say I have the following network, which is already trained.

import torch
import torch.nn as nn


class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
        # Five convolution layers, each followed by a ReLU activation
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # Three fully connected layers
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

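# model_urls is the dict of pretrained-weight URLs from torchvision's alexnet module (older releases)
from torch.hub import load_state_dict_from_url

model = AlexNet()
state_dict = load_state_dict_from_url(model_urls['alexnet'])
model.load_state_dict(state_dict)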

For Testing:

I want to change the activations of this network to 9-8-5-5-7 bits (as I understand it, these are the outputs of the five ReLUs in self.features) and the weights of all convolution layers to 11 bits.

I also want to change the weights of the fully connected layers to 10-9-9 bits respectively.

Can someone please point me in the right direction or share a code snippet? Thanks in advance.


What are these numbers of bits? Floating-point numbers are usually 64, 32 or 16 bits. Do you refer to other encodings here?

Hello, thanks for your reply. As I understand it, they refer to the number of bits used after quantization. [page-4]
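
To make "number of bits" concrete: an n-bit quantizer can represent at most 2**n distinct values, so quantizing to 5 bits snaps every value onto one of 32 levels. Below is a minimal sketch in plain Python, assuming uniform min-max quantization. The helper name `uniform_quantize` is made up for illustration; it is not part of PyTorch or of the quant file mentioned in this thread.

```python
def uniform_quantize(values, bits):
    """Snap each value onto one of 2**bits evenly spaced levels in [min, max]."""
    lo, hi = min(values), max(values)
    if hi == lo:                      # constant input: nothing to quantize
        return list(values)
    levels = 2 ** bits - 1            # number of steps between lo and hi
    step = (hi - lo) / levels
    return [lo + round((v - lo) / step) * step for v in values]


# 1001 distinct "weights" in [-0.5, 0.5]
weights = [i / 1000.0 for i in range(-500, 501)]
q5 = uniform_quantize(weights, bits=5)

print(len(set(weights)))   # distinct values before quantization
print(len(set(q5)))        # at most 2**5 = 32 distinct values remain
```

The values are still stored as ordinary floats; only the set of values they can take is restricted, which is usually what is meant by simulating a given bit width at test time.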

Currently I have tried this solution [inside utee folder quant file]

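from collections import OrderedDict
from utee import quant   # the quant file inside the utee folder

# One bit width per state_dict entry (weight and bias of each layer):
# 11 bits for the five conv layers, then 10-9-9 bits for the three FC layers
arr = [11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 10, 10, 9, 9, 9, 9]
bit_num = OrderedDict(zip(state_dict.keys(), arr))

state_dict_quant = OrderedDict()

#state_dict_quant['features.0.weight'] = quant.log_minmax_quantize(model.features[0].weight, bits=11)

# Quantize every tensor with the bit width assigned to its key
for k, v in state_dict.items():
    state_dict_quant[k] = quant.log_minmax_quantize(v, bits=bit_num[k])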

It looks like it is doing the job, though I still have to work out how to change the number of bits for the activations, since these are computed at run time. Maybe a “hook” can help.
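
The per-layer weight quantization described above can also be sketched without the external quant file. This is a minimal sketch under stated assumptions: `minmax_quantize` is a hypothetical stand-in that uses plain uniform min-max quantization (not the `log_minmax_quantize` from the utee folder), the small `nn.Sequential` model stands in for AlexNet, and the bit widths 4 and 3 are chosen only to keep the demo small. The 11 and 10-9-9 values from the question would go into the same dictionary.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def minmax_quantize(t, bits):
    """Uniformly quantize a tensor onto 2**bits levels between its min and max."""
    lo, hi = t.min(), t.max()
    if hi == lo:                      # constant tensor: nothing to quantize
        return t.clone()
    step = (hi - lo) / (2 ** bits - 1)
    return torch.round((t - lo) / step) * step + lo


# Small stand-in network; a larger model is handled the same way.
model = nn.Sequential(OrderedDict([
    ("conv", nn.Conv2d(3, 8, kernel_size=3, padding=1)),
    ("relu", nn.ReLU()),
    ("pool", nn.AdaptiveAvgPool2d((2, 2))),
    ("flat", nn.Flatten()),
    ("fc", nn.Linear(8 * 2 * 2, 10)),
]))

# Bit width per layer, keyed by module name instead of by position in a list.
layer_bits = {"conv": 4, "fc": 3}

state_dict_quant = OrderedDict()
for key, tensor in model.state_dict().items():
    layer_name = key.rsplit(".", 1)[0]          # "conv.weight" -> "conv"
    state_dict_quant[key] = minmax_quantize(tensor, layer_bits[layer_name])

# Without this call the model keeps using its original float weights.
model.load_state_dict(state_dict_quant)
model.eval()

print(torch.unique(model.conv.weight).numel())   # at most 2**4 = 16 distinct values
```

Keying the bit widths by layer name avoids relying on the order of the state_dict entries. Like the list-based version, this sketch quantizes the biases with the same bit width as the weights of their layer.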

You can take a look at the quantization module to see if it fits your needs?

I am not sure how this module can help with setting a custom number of bits. Is that possible using the quantization module?

Also, can you please help me with a code snippet to change the number of bits for the activations? Currently I am using this one, but it doesn’t seem to work.

arr1 = [9, 8, 5, 5, 7]     # bit widths for the quantized activations
#arr1 = [2,2,2,2,2]

def quantizee(self, input, output):
    # forward hook: quantize the output with the next bit width in arr1
    print('output size:', output.size())
    output = quant.min_max_quantize(output, arr1[quantizee.counter])
    quantizee.counter += 1
    if quantizee.counter == 5:
        quantizee.counter = 0

quantizee.counter = 0

count1 = 0          # number of layers hooked so far (five activations to quantize)
for name, module in model.named_modules():
    if isinstance(module, torch.nn.modules.Conv2d) and count1 < 5:
        module.register_forward_hook(quantizee)
        count1 += 1

I have tried this one too, but with no success.

#arr1 = [9,8,5,5,7]
arr1 = [2, 2, 2, 2, 2]
count1 = 0
l = OrderedDict()
for k, v in state_dict_quant.items():
    if isinstance(v, nn.ReLU):
        quant_layer = quant.NormalQuant('{}_quant'.format(k), bits=arr1[count1], quant_func=quant.log_minmax_quantize)
        l['{}_{}_quant'.format(k, type)] = quant_layer
        count1 += 1
        if count1 == 5:
            count1 = 0
        l[k] = v
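
As far as I can tell, there are two likely reasons the attempts above have no effect. In the hook version, assigning to `output` inside the hook only rebinds a local name; a forward hook changes the module's result only if it returns the new tensor. In the second version, the loop runs over a state_dict, which holds tensors rather than modules, so the `isinstance(v, nn.ReLU)` check is never true. Below is a minimal sketch of the hook approach, assuming uniform min-max quantization. `minmax_quantize` and `make_quant_hook` are hypothetical helpers, the two-layer network stands in for `self.features`, and the bit widths 3 and 2 are chosen only to keep the demo small.

```python
import torch
import torch.nn as nn


def minmax_quantize(t, bits):
    """Uniformly quantize a tensor onto 2**bits levels between its min and max."""
    lo, hi = t.min(), t.max()
    if hi == lo:                      # constant tensor: nothing to quantize
        return t.clone()
    step = (hi - lo) / (2 ** bits - 1)
    return torch.round((t - lo) / step) * step + lo


def make_quant_hook(bits):
    """Build a forward hook that replaces a module's output with a quantized copy."""
    def hook(module, inputs, output):
        # Returning a tensor from a forward hook replaces the module's output.
        return minmax_quantize(output, bits)
    return hook


# Small stand-in for the convolutional part of the network.
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.ReLU(),
)

activation_bits = [3, 2]   # one entry per ReLU, in forward order

# Attach one hook per ReLU, each with its own bit width.
relus = [m for m in features.modules() if isinstance(m, nn.ReLU)]
handles = [m.register_forward_hook(make_quant_hook(b))
           for m, b in zip(relus, activation_bits)]

with torch.no_grad():
    out = features(torch.randn(1, 3, 16, 16))

print(torch.unique(out).numel())   # at most 2**2 = 4 distinct activation values

# Removing the hooks restores the unquantized behaviour.
for h in handles:
    h.remove()
```

Binding the bit width inside each hook's closure avoids the global counter, so the mapping from layer to bit width does not depend on how many forward passes have already run.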