_combine_histograms histogram_with_output_range = torch.zeros((Nbins * downsample_rate), device=orig_hist.device) RuntimeError: Trying to create tensor with negative dimension -4398046511104: [-4398046511104]

Hello everyone.
This is a followup question concerning this one
The issue is that everything runs fine, except that at some point this weird error occurs when running this specific block:

class SEBlock(nn.Module):
    """Channel-gating ("SE") block: rescale ``x`` by per-channel weights.

    Every channel of the input is average-pooled to a single value, passed
    through ``Linear -> PReLU_Quantized -> Linear -> Sigmoid``, and the
    resulting ``(b, c, 1, 1)`` tensor is multiplied with the input through a
    ``FloatFunctional``.

    Args:
        channel: number of channels ``c`` of the 4-D input.
        reduction: divisor for the hidden width; the first ``Linear`` maps
            ``channel`` features to ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Functional wrapper used for the final element-wise multiply.
        self.mult_xy = nn.quantized.FloatFunctional()

        layers = [
            nn.Linear(channel, hidden),
            nn.PReLU(),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

        # Expose each stage under its own attribute so forward() can call
        # them one at a time and swap in the quantized PReLU.
        self.fc1, self.prelu, self.fc2, self.sigmoid = self.fc
        # PReLU_Quantized is defined elsewhere; it is built from the float
        # PReLU created above.
        self.prelu_q = PReLU_Quantized(self.prelu)

    def forward(self, x):
        """Return ``x`` multiplied by its sigmoid-gated channel weights."""
        batch, channels, _, _ = x.size()

        # One pooled value per channel, flattened to (batch, channels).
        pooled = self.avg_pool(x).view(batch, channels)

        # Same pipeline as self.fc, but with prelu_q in place of prelu.
        activated = self.prelu_q(self.fc1(pooled))
        gate = self.sigmoid(self.fc2(activated)).view(batch, channels, 1, 1)

        # Float form would be `x * gate`.
        return self.mult_xy.mul(x, gate)

It runs fine several times, but at some point it fails with the following error message:

Traceback (most recent call last):
  File "d:\Codes\org\python\Quantization\quantizer.py", line 248, in <module>
    quantize_test()
  File "d:\Codes\org\python\Quantization\quantizer.py", line 230, in quantize_test
    evaluate(model, dtloader, neval_batches=num_calibration_batches)
  File "d:\Codes\org\python\Quantization\quantizer.py", line 145, in evaluate
    features = model(image.unsqueeze(0))
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 473, in forward
    x = self.layer3(x)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\container.py", line 100, in forward
    input = module(input)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 387, in forward
    out = self.se(out)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 345, in forward
    y = self.prelu_q(y)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 221, in forward
    inputs = self.quantized_op.add(tmax, weight_min_res).unsqueeze(0)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\quantized\modules\functional_modules.py", line 43, in add
    r = self.activation_post_process(r)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\quantization\observer.py", line 833, in forward
    self.bins)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\quantization\observer.py", line 789, in _combine_histograms
    histogram_with_output_range = torch.zeros((Nbins * downsample_rate), device=orig_hist.device)
RuntimeError: Trying to create tensor with negative dimension -4398046511104: [-4398046511104]

What am I missing here?

Any help is greatly appreciated.

Hi,

Can you check what the inputs to the add operation are at "d:\codes\org\python\FV\quantized_models.py", line 221?
It looks like it is not handling these properly.

If you could give us a set of inputs that reproduces this issue so that we can reproduce on our side, that would be very helpful!

Hi,
This is the latest error I get after updating PReLU_Quantized (a link to the implementation is here, by the way).
The inputs are included in the log below (as X). The error only happens in the SEBlock module, whose definition is also given below; all other modules that use the PReLU_Quantized module run fine, except SEBlock:

Size (MB): 89.297826
QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
Post Training Quantization Prepare: Inserting Observers

 Inverted Residual Block:After observer insertion

 Conv2d(
  3, 64, kernel_size=(3, 3), stride=(1, 1)
  (activation_post_process): HistogramObserver()
)
<inside se forward:>
X: tensor([[-1.5691, -0.7516, -0.7360, -0.6458]])
--------------------------
<inside se forward:>
X: tensor([[ 3.6605e-01,  3.3855e+00, -5.0032e-19, -9.0280e-19]])
Traceback (most recent call last):
  File "d:\Codes\org\python\Quantization\quantizer.py", line 266, in <module>
    quantize_test()
  File "d:\Codes\org\python\Quantization\quantizer.py", line 248, in quantize_test
    evaluate(model, dtloader, neval_batches=num_calibration_batches)
  File "d:\Codes\org\python\Quantization\quantizer.py", line 152, in evaluate
    features = model(image.unsqueeze(0))
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 576, in forward
    x = self.layer1(x)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\container.py", line 100, in forward
    input = module(input)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 489, in forward
    out = self.se(out)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 447, in forward
    y = self.prelu_q(y)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 322, in forward
    inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\quantized\modules\functional_modules.py", line 43, in add
    r = self.activation_post_process(r)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\quantization\observer.py", line 833, in forward
    self.bins)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\quantization\observer.py", line 789, in _combine_histograms
    histogram_with_output_range = torch.zeros((Nbins * downsample_rate), device=orig_hist.device)
RuntimeError: Trying to create tensor with negative dimension -4398046511104: [-4398046511104]

And this is what the SE block looks like:

class SEBlock(nn.Module):
    """Channel-gating ("SE") block with debug prints in ``forward``.

    Every channel of the input is average-pooled to a single value, passed
    through ``Linear -> PReLU_Quantized -> Linear -> Sigmoid``, and the
    resulting ``(b, c, 1, 1)`` tensor is multiplied with the input through a
    ``FloatFunctional``. The output of the first ``Linear`` (the tensor fed
    to ``PReLU_Quantized``) is printed on every call.

    Args:
        channel: number of channels ``c`` of the 4-D input.
        reduction: divisor for the hidden width; the first ``Linear`` maps
            ``channel`` features to ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Functional wrapper used for the final element-wise multiply.
        self.mult_xy = nn.quantized.FloatFunctional()

        layers = [
            nn.Linear(channel, hidden),
            nn.PReLU(),
            nn.Linear(hidden, channel),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

        # Expose each stage under its own attribute so forward() can call
        # them one at a time and swap in the quantized PReLU.
        self.fc1, self.prelu, self.fc2, self.sigmoid = self.fc
        # PReLU_Quantized is defined elsewhere; it is built from the float
        # PReLU created above.
        self.prelu_q = PReLU_Quantized(self.prelu)

    def forward(self, x):
        """Return ``x`` multiplied by its sigmoid-gated channel weights."""
        print(f'<inside se forward:>')
        batch, channels, _, _ = x.size()

        # One pooled value per channel, flattened to (batch, channels).
        pooled = self.avg_pool(x).view(batch, channels)

        # Dump the exact tensor that is about to enter prelu_q.
        pre_activation = self.fc1(pooled)
        print(f'X: {pre_activation}')

        activated = self.prelu_q(pre_activation)
        gate = self.sigmoid(self.fc2(activated)).view(batch, channels, 1, 1)
        print('--------------------------')

        # Float form would be `x * gate`.
        return self.mult_xy.mul(x, gate)

And this is the output when I use PReLU instead of PReLU_Quantized in the SE block only (all other instances of PReLU are replaced with PReLU_Quantized in the other modules of the ResNet):

Summary
Size (MB): 89.29209
QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
Post Training Quantization Prepare: Inserting Observers

 Inverted Residual Block:After observer insertion

 Conv2d(
  3, 64, kernel_size=(3, 3), stride=(1, 1)
  (activation_post_process): HistogramObserver()
)
<inside se forward:>
X: tensor([[-1.5691, -0.7516, -0.7360, -0.6458]])
--------------------------
<inside se forward:>
X: tensor([[ 3.6605e-01,  3.3855e+00, -5.0032e-19, -9.0280e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0513, -0.0656, -0.4529,  0.0653, -0.4762, -0.6304, -1.5043, -0.9484]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8730,  1.6650, -0.5135, -0.6811, -0.0392, -0.4689, -0.1496,  0.0717]])
--------------------------
<inside se forward:>
X: tensor([[-1.8759, -0.8886, -1.3295, -0.5375,  0.7598, -0.8526, -1.9066,  0.0985,
         -0.1461, -0.5857,  0.1513, -0.3050,  0.1955, -0.8470,  0.4528,  0.9689]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6184e+00, -2.2714e-18,  2.8052e+00,  1.0378e+01,  4.6361e-05,
          1.0644e+01,  1.4302e-02,  2.6143e-02,  2.4926e-05,  6.2237e+00,
          8.8411e-05,  6.4360e+00,  3.3530e+00,  3.9302e-05,  8.1652e+00,
          8.7950e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1687e+00,  3.1469e+00, -1.1788e+01,  4.9410e-02,  1.7272e+00,
         -3.0913e+00,  1.1572e+00, -6.7104e+00,  1.1371e+01,  4.8926e+00,
         -1.3102e+00, -4.9774e+00, -4.1444e+00, -6.3367e-01, -1.5672e+00,
          4.2629e+00,  3.2491e+00, -4.6632e+00,  5.9241e-01, -2.4883e+00,
          5.2599e+00, -7.1710e+00,  4.7197e+00,  7.2724e+00, -2.3363e+00,
         -2.2564e+00,  5.4431e+00, -2.2832e-12,  1.9732e+00,  1.1682e+00,
          6.1555e+00,  6.3574e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2785e-01,  1.1057e+00,  3.1581e-07,  9.7595e-01,  9.7386e-03,
          8.4260e-07,  2.4243e-01,  2.1749e+00,  4.5704e-01,  2.9307e+00,
          3.2384e+00,  2.6099e+00,  1.7640e-01,  4.3206e-04,  9.9380e-18,
          1.3450e-11,  1.5721e-09,  2.7632e-07,  3.6721e-04,  2.1237e-07,
          1.8839e-10,  1.8423e-02,  1.8514e-13,  4.3584e+00,  1.0972e-01,
          7.5909e-03,  4.3828e-02,  2.9285e-02,  8.3840e-07, -2.6420e-19,
          3.6933e-01,  1.0561e+00]])
--------------------------
0-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.5517, -0.8007, -0.7286, -0.6478]])
--------------------------
<inside se forward:>
X: tensor([[ 5.0945e-01,  3.2514e+00, -5.2950e-19, -9.1256e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0556, -0.1015, -0.4792,  0.0956, -0.4782, -0.6346, -1.4946, -0.9745]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8254,  1.6459, -0.4613, -0.6462, -0.0376, -0.4217, -0.0865,  0.0773]])
--------------------------
<inside se forward:>
X: tensor([[-1.8807, -0.8899, -1.3275, -0.5305,  0.7527, -0.8557, -1.9068,  0.1042,
         -0.1444, -0.5798,  0.1493, -0.3055,  0.1952, -0.8383,  0.4532,  0.9664]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6193e+00, -2.2732e-18,  2.8069e+00,  1.0384e+01,  4.6389e-05,
          1.0650e+01,  1.4310e-02,  2.6159e-02,  2.4941e-05,  6.2275e+00,
          8.8464e-05,  6.4398e+00,  3.3551e+00,  3.9326e-05,  8.1701e+00,
          8.8003e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1444e+00,  3.1584e+00, -1.1794e+01,  4.9510e-02,  1.7366e+00,
         -3.0976e+00,  1.1594e+00, -6.7127e+00,  1.1380e+01,  4.9035e+00,
         -1.3231e+00, -4.9740e+00, -4.1439e+00, -6.3774e-01, -1.5777e+00,
          4.2655e+00,  3.2341e+00, -4.6753e+00,  6.1677e-01, -2.4898e+00,
          5.2556e+00, -7.1508e+00,  4.7271e+00,  7.2643e+00, -2.3301e+00,
         -2.2546e+00,  5.4412e+00, -2.2872e-12,  1.9668e+00,  1.1764e+00,
          6.1590e+00,  6.3575e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2778e-01,  1.1051e+00,  3.1564e-07,  9.7544e-01,  9.7335e-03,
          8.4216e-07,  2.4230e-01,  2.1737e+00,  4.5681e-01,  2.9292e+00,
          3.2367e+00,  2.6086e+00,  1.7631e-01,  4.3183e-04,  9.9393e-18,
          1.3443e-11,  1.5713e-09,  2.7617e-07,  3.6702e-04,  2.1226e-07,
          1.8829e-10,  1.8414e-02,  1.8504e-13,  4.3561e+00,  1.0967e-01,
          7.5869e-03,  4.3805e-02,  2.9270e-02,  8.3797e-07, -2.6259e-19,
          3.6914e-01,  1.0555e+00]])
--------------------------
1-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.6008, -0.7627, -0.7418, -0.6562]])
--------------------------
<inside se forward:>
X: tensor([[ 4.6180e-01,  3.2969e+00, -5.1091e-19, -8.5673e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0860, -0.0888, -0.4410,  0.0515, -0.4853, -0.6203, -1.4854, -0.9521]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8713,  1.6702, -0.5249, -0.6848, -0.0393, -0.4817, -0.1603,  0.0686]])
--------------------------
<inside se forward:>
X: tensor([[-1.8888, -0.8991, -1.3308, -0.5351,  0.7626, -0.8547, -1.9075,  0.1075,
         -0.1457, -0.5770,  0.1518, -0.3068,  0.2023, -0.8418,  0.4610,  0.9654]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6180e+00, -2.2720e-18,  2.8046e+00,  1.0376e+01,  4.6351e-05,
          1.0642e+01,  1.4299e-02,  2.6138e-02,  2.4921e-05,  6.2225e+00,
          8.8393e-05,  6.4347e+00,  3.3524e+00,  3.9294e-05,  8.1636e+00,
          8.7932e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1925e+00,  3.1589e+00, -1.1792e+01,  4.9472e-02,  1.7246e+00,
         -3.0884e+00,  1.1586e+00, -6.7112e+00,  1.1375e+01,  4.8954e+00,
         -1.3047e+00, -4.9715e+00, -4.1392e+00, -6.4653e-01, -1.5772e+00,
          4.2795e+00,  3.2537e+00, -4.6607e+00,  5.9939e-01, -2.4853e+00,
          5.2615e+00, -7.1921e+00,  4.7311e+00,  7.2626e+00, -2.3221e+00,
         -2.2574e+00,  5.4390e+00, -2.2799e-12,  1.9636e+00,  1.1820e+00,
          6.1593e+00,  6.3554e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2775e-01,  1.1048e+00,  3.1557e-07,  9.7521e-01,  9.7312e-03,
          8.4196e-07,  2.4225e-01,  2.1732e+00,  4.5670e-01,  2.9285e+00,
          3.2360e+00,  2.6079e+00,  1.7627e-01,  4.3173e-04,  9.9377e-18,
          1.3440e-11,  1.5710e-09,  2.7611e-07,  3.6693e-04,  2.1221e-07,
          1.8825e-10,  1.8409e-02,  1.8500e-13,  4.3551e+00,  1.0964e-01,
          7.5851e-03,  4.3795e-02,  2.9263e-02,  8.3777e-07, -2.6075e-19,
          3.6905e-01,  1.0553e+00]])
--------------------------
2-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.5790, -0.8100, -0.7292, -0.6440]])
--------------------------
<inside se forward:>
X: tensor([[ 5.0116e-01,  3.2659e+00, -5.2126e-19, -8.5920e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0427, -0.0929, -0.4953,  0.0674, -0.4784, -0.6115, -1.4972, -0.9645]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8374,  1.6551, -0.4788, -0.6555, -0.0380, -0.4393, -0.1045,  0.0742]])
--------------------------
<inside se forward:>
X: tensor([[-1.8727, -0.8932, -1.3280, -0.5371,  0.7591, -0.8533, -1.8998,  0.1003,
         -0.1452, -0.5813,  0.1475, -0.3055,  0.2016, -0.8411,  0.4535,  0.9559]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6189e+00, -2.2717e-18,  2.8060e+00,  1.0381e+01,  4.6375e-05,
          1.0647e+01,  1.4306e-02,  2.6151e-02,  2.4933e-05,  6.2256e+00,
          8.8438e-05,  6.4379e+00,  3.3541e+00,  3.9314e-05,  8.1676e+00,
          8.7976e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1427e+00,  3.1480e+00, -1.1763e+01,  4.9449e-02,  1.7342e+00,
         -3.0890e+00,  1.1581e+00, -6.7127e+00,  1.1348e+01,  4.8951e+00,
         -1.3154e+00, -4.9691e+00, -4.1414e+00, -6.4151e-01, -1.5783e+00,
          4.2688e+00,  3.2439e+00, -4.6649e+00,  6.0231e-01, -2.4855e+00,
          5.2647e+00, -7.1494e+00,  4.7290e+00,  7.2520e+00, -2.3288e+00,
         -2.2466e+00,  5.4410e+00, -2.2847e-12,  1.9777e+00,  1.1817e+00,
          6.1588e+00,  6.3552e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2778e-01,  1.1050e+00,  3.1563e-07,  9.7541e-01,  9.7331e-03,
          8.4213e-07,  2.4229e-01,  2.1736e+00,  4.5679e-01,  2.9291e+00,
          3.2366e+00,  2.6084e+00,  1.7631e-01,  4.3181e-04,  9.9368e-18,
          1.3443e-11,  1.5713e-09,  2.7616e-07,  3.6700e-04,  2.1225e-07,
          1.8828e-10,  1.8413e-02,  1.8503e-13,  4.3559e+00,  1.0966e-01,
          7.5866e-03,  4.3804e-02,  2.9269e-02,  8.3793e-07, -2.6231e-19,
          3.6912e-01,  1.0555e+00]])
--------------------------
3-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.6226, -0.7605, -0.6854, -0.5836]])
--------------------------
<inside se forward:>
X: tensor([[ 2.6039e-01,  3.4835e+00, -4.8167e-19, -8.5980e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0699, -0.0526, -0.4319, -0.0069, -0.4890, -0.6087, -1.4835, -0.9184]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8828,  1.6724, -0.5539, -0.7054, -0.0402, -0.5061, -0.2002,  0.0661]])
--------------------------
<inside se forward:>
X: tensor([[-1.8790, -0.8969, -1.3365, -0.5384,  0.7664, -0.8571, -1.9043,  0.1059,
         -0.1459, -0.5847,  0.1542, -0.3094,  0.2076, -0.8439,  0.4567,  0.9642]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6174e+00, -2.2710e-18,  2.8035e+00,  1.0371e+01,  4.6333e-05,
          1.0638e+01,  1.4293e-02,  2.6128e-02,  2.4911e-05,  6.2200e+00,
          8.8358e-05,  6.4321e+00,  3.3510e+00,  3.9279e-05,  8.1603e+00,
          8.7897e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1583e+00,  3.1523e+00, -1.1765e+01,  4.9511e-02,  1.7292e+00,
         -3.0851e+00,  1.1595e+00, -6.7154e+00,  1.1350e+01,  4.9005e+00,
         -1.3040e+00, -4.9675e+00, -4.1433e+00, -6.3643e-01, -1.5745e+00,
          4.2669e+00,  3.2492e+00, -4.6569e+00,  6.0002e-01, -2.4789e+00,
          5.2519e+00, -7.1619e+00,  4.7275e+00,  7.2465e+00, -2.3229e+00,
         -2.2525e+00,  5.4448e+00, -2.2806e-12,  1.9732e+00,  1.1739e+00,
          6.1550e+00,  6.3576e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2778e-01,  1.1050e+00,  3.1563e-07,  9.7540e-01,  9.7331e-03,
          8.4212e-07,  2.4229e-01,  2.1736e+00,  4.5679e-01,  2.9291e+00,
          3.2366e+00,  2.6084e+00,  1.7630e-01,  4.3181e-04,  9.9369e-18,
          1.3443e-11,  1.5712e-09,  2.7616e-07,  3.6700e-04,  2.1225e-07,
          1.8828e-10,  1.8413e-02,  1.8503e-13,  4.3559e+00,  1.0966e-01,
          7.5866e-03,  4.3803e-02,  2.9269e-02,  8.3793e-07, -2.6177e-19,
          3.6912e-01,  1.0555e+00]])
--------------------------
4-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.5559, -0.7016, -0.7545, -0.6793]])
--------------------------
<inside se forward:>
X: tensor([[ 4.6992e-01,  3.2951e+00, -5.1868e-19, -8.9299e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0106, -0.0831, -0.5151,  0.0650, -0.4869, -0.6094, -1.5116, -0.9355]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8588,  1.6723, -0.4774, -0.6520, -0.0379, -0.4428, -0.0917,  0.0721]])
--------------------------
<inside se forward:>
X: tensor([[-1.8655, -0.8893, -1.3313, -0.5367,  0.7590, -0.8533, -1.9023,  0.1008,
         -0.1428, -0.5834,  0.1448, -0.3016,  0.2040, -0.8361,  0.4534,  0.9494]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6194e+00, -2.2728e-18,  2.8070e+00,  1.0384e+01,  4.6391e-05,
          1.0651e+01,  1.4311e-02,  2.6160e-02,  2.4942e-05,  6.2277e+00,
          8.8468e-05,  6.4401e+00,  3.3552e+00,  3.9328e-05,  8.1704e+00,
          8.8006e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1170e+00,  3.1500e+00, -1.1769e+01,  4.9446e-02,  1.7362e+00,
         -3.0951e+00,  1.1581e+00, -6.7183e+00,  1.1354e+01,  4.8964e+00,
         -1.3110e+00, -4.9689e+00, -4.1461e+00, -6.4890e-01, -1.5875e+00,
          4.2782e+00,  3.2361e+00, -4.6685e+00,  6.0150e-01, -2.4799e+00,
          5.2726e+00, -7.1287e+00,  4.7384e+00,  7.2532e+00, -2.3235e+00,
         -2.2367e+00,  5.4463e+00, -2.2915e-12,  1.9780e+00,  1.1893e+00,
          6.1668e+00,  6.3629e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2774e-01,  1.1047e+00,  3.1554e-07,  9.7513e-01,  9.7304e-03,
          8.4189e-07,  2.4223e-01,  2.1730e+00,  4.5666e-01,  2.9283e+00,
          3.2357e+00,  2.6077e+00,  1.7626e-01,  4.3169e-04,  9.9353e-18,
          1.3439e-11,  1.5708e-09,  2.7609e-07,  3.6690e-04,  2.1219e-07,
          1.8823e-10,  1.8408e-02,  1.8498e-13,  4.3547e+00,  1.0963e-01,
          7.5845e-03,  4.3792e-02,  2.9261e-02,  8.3770e-07, -2.6081e-19,
          3.6902e-01,  1.0552e+00]])
--------------------------
5-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.5922, -0.7833, -0.8099, -0.7581]])
--------------------------
<inside se forward:>
X: tensor([[ 6.0425e-01,  3.1537e+00, -5.2917e-19, -8.2412e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0295, -0.1079, -0.5239,  0.1099, -0.4906, -0.6187, -1.5178, -0.9515]])
--------------------------
<inside se forward:>
X: tensor([[ 4.9047,  1.7059, -0.4654, -0.6338, -0.0371, -0.4419, -0.0531,  0.0689]])
--------------------------
<inside se forward:>
X: tensor([[-1.8792, -0.8972, -1.3274, -0.5352,  0.7649, -0.8542, -1.9078,  0.1055,
         -0.1455, -0.5737,  0.1437, -0.3026,  0.2050, -0.8408,  0.4609,  0.9527]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6192e+00, -2.2734e-18,  2.8065e+00,  1.0383e+01,  4.6383e-05,
          1.0649e+01,  1.4309e-02,  2.6156e-02,  2.4938e-05,  6.2268e+00,
          8.8454e-05,  6.4391e+00,  3.3547e+00,  3.9321e-05,  8.1692e+00,
          8.7993e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1462e+00,  3.1594e+00, -1.1777e+01,  4.9445e-02,  1.7250e+00,
         -3.0903e+00,  1.1580e+00, -6.6971e+00,  1.1362e+01,  4.8978e+00,
         -1.3202e+00, -4.9701e+00, -4.1377e+00, -6.3982e-01, -1.5717e+00,
          4.2688e+00,  3.2314e+00, -4.6666e+00,  6.1283e-01, -2.4762e+00,
          5.2739e+00, -7.1517e+00,  4.7211e+00,  7.2673e+00, -2.3338e+00,
         -2.2474e+00,  5.4291e+00, -2.2837e-12,  1.9676e+00,  1.1787e+00,
          6.1559e+00,  6.3495e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2772e-01,  1.1046e+00,  3.1549e-07,  9.7498e-01,  9.7289e-03,
          8.4176e-07,  2.4219e-01,  2.1727e+00,  4.5659e-01,  2.9278e+00,
          3.2352e+00,  2.6073e+00,  1.7623e-01,  4.3163e-04,  9.9370e-18,
          1.3437e-11,  1.5706e-09,  2.7604e-07,  3.6684e-04,  2.1216e-07,
          1.8820e-10,  1.8405e-02,  1.8495e-13,  4.3541e+00,  1.0962e-01,
          7.5833e-03,  4.3785e-02,  2.9256e-02,  8.3757e-07, -2.6052e-19,
          3.6896e-01,  1.0550e+00]])
--------------------------
6-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.5156, -0.5839, -0.7718, -0.6881]])
--------------------------
<inside se forward:>
X: tensor([[ 6.3789e-01,  3.1470e+00, -5.4607e-19, -8.8140e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0068, -0.1239, -0.5419,  0.1311, -0.4739, -0.6220, -1.5159, -1.0039]])
--------------------------
<inside se forward:>
X: tensor([[ 4.7764,  1.6289, -0.3860, -0.5940, -0.0352, -0.3554,  0.0103,  0.0848]])
--------------------------
<inside se forward:>
X: tensor([[-1.8759, -0.8883, -1.3219, -0.5339,  0.7527, -0.8555, -1.9051,  0.0963,
         -0.1418, -0.5765,  0.1501, -0.2970,  0.1911, -0.8370,  0.4527,  0.9548]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6203e+00, -2.2739e-18,  2.8086e+00,  1.0390e+01,  4.6417e-05,
          1.0657e+01,  1.4319e-02,  2.6175e-02,  2.4956e-05,  6.2312e+00,
          8.8518e-05,  6.4437e+00,  3.3571e+00,  3.9350e-05,  8.1750e+00,
          8.8056e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1541e+00,  3.1531e+00, -1.1772e+01,  4.9404e-02,  1.7326e+00,
         -3.0931e+00,  1.1571e+00, -6.6943e+00,  1.1357e+01,  4.8937e+00,
         -1.3274e+00, -4.9758e+00, -4.1305e+00, -6.4647e-01, -1.5764e+00,
          4.2726e+00,  3.2396e+00, -4.6719e+00,  6.0704e-01, -2.4865e+00,
          5.2721e+00, -7.1595e+00,  4.7218e+00,  7.2695e+00, -2.3445e+00,
         -2.2482e+00,  5.4221e+00, -2.2827e-12,  1.9751e+00,  1.1886e+00,
          6.1566e+00,  6.3400e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2782e-01,  1.1054e+00,  3.1574e-07,  9.7576e-01,  9.7366e-03,
          8.4243e-07,  2.4238e-01,  2.1744e+00,  4.5695e-01,  2.9301e+00,
          3.2378e+00,  2.6094e+00,  1.7637e-01,  4.3197e-04,  9.9450e-18,
          1.3448e-11,  1.5718e-09,  2.7626e-07,  3.6714e-04,  2.1232e-07,
          1.8835e-10,  1.8419e-02,  1.8510e-13,  4.3575e+00,  1.0970e-01,
          7.5893e-03,  4.3819e-02,  2.9279e-02,  8.3823e-07, -2.6201e-19,
          3.6925e-01,  1.0558e+00]])
--------------------------
7-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.5567, -0.7524, -0.7620, -0.6805]])
--------------------------
<inside se forward:>
X: tensor([[ 5.3279e-01,  3.2445e+00, -5.2411e-19, -8.5973e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0248, -0.1011, -0.5172,  0.0823, -0.4737, -0.6192, -1.4961, -0.9762]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8410,  1.6705, -0.4254, -0.6104, -0.0360, -0.4001, -0.0183,  0.0759]])
--------------------------
<inside se forward:>
X: tensor([[-1.8740, -0.8943, -1.3243, -0.5337,  0.7550, -0.8610, -1.9063,  0.1108,
         -0.1408, -0.5770,  0.1506, -0.3089,  0.1984, -0.8347,  0.4544,  0.9591]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6191e+00, -2.2732e-18,  2.8065e+00,  1.0383e+01,  4.6383e-05,
          1.0649e+01,  1.4309e-02,  2.6156e-02,  2.4938e-05,  6.2267e+00,
          8.8453e-05,  6.4390e+00,  3.3546e+00,  3.9321e-05,  8.1691e+00,
          8.7992e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1553e+00,  3.1582e+00, -1.1776e+01,  4.9516e-02,  1.7335e+00,
         -3.0909e+00,  1.1595e+00, -6.7080e+00,  1.1362e+01,  4.9002e+00,
         -1.3237e+00, -4.9679e+00, -4.1376e+00, -6.4026e-01, -1.5758e+00,
          4.2652e+00,  3.2360e+00, -4.6691e+00,  6.1957e-01, -2.4899e+00,
          5.2536e+00, -7.1605e+00,  4.7257e+00,  7.2488e+00, -2.3271e+00,
         -2.2548e+00,  5.4335e+00, -2.2811e-12,  1.9611e+00,  1.1809e+00,
          6.1551e+00,  6.3494e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2783e-01,  1.1055e+00,  3.1575e-07,  9.7577e-01,  9.7367e-03,
          8.4244e-07,  2.4238e-01,  2.1744e+00,  4.5696e-01,  2.9302e+00,
          3.2378e+00,  2.6094e+00,  1.7637e-01,  4.3197e-04,  9.9456e-18,
          1.3448e-11,  1.5718e-09,  2.7626e-07,  3.6714e-04,  2.1233e-07,
          1.8835e-10,  1.8420e-02,  1.8510e-13,  4.3576e+00,  1.0970e-01,
          7.5894e-03,  4.3820e-02,  2.9280e-02,  8.3824e-07, -2.6233e-19,
          3.6926e-01,  1.0559e+00]])
--------------------------
8-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.6060, -0.9100, -0.7711, -0.7195]])
--------------------------
<inside se forward:>
X: tensor([[ 5.6481e-01,  3.2033e+00, -5.2471e-19, -8.0308e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0948, -0.1106, -0.4654,  0.0768, -0.5028, -0.6202, -1.4778, -0.9581]])
--------------------------
<inside se forward:>
X: tensor([[ 4.9064,  1.6963, -0.5052, -0.6644, -0.0385, -0.4721, -0.1160,  0.0672]])
--------------------------
<inside se forward:>
X: tensor([[-1.8868, -0.8981, -1.3322, -0.5298,  0.7566, -0.8556, -1.9039,  0.1134,
         -0.1447, -0.5744,  0.1480, -0.3113,  0.2017, -0.8359,  0.4564,  0.9658]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6177e+00, -2.2718e-18,  2.8040e+00,  1.0373e+01,  4.6342e-05,
          1.0640e+01,  1.4296e-02,  2.6133e-02,  2.4916e-05,  6.2212e+00,
          8.8375e-05,  6.4333e+00,  3.3517e+00,  3.9286e-05,  8.1619e+00,
          8.7914e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1471e+00,  3.1531e+00, -1.1779e+01,  4.9447e-02,  1.7370e+00,
         -3.0912e+00,  1.1580e+00, -6.7101e+00,  1.1363e+01,  4.9010e+00,
         -1.3083e+00, -4.9699e+00, -4.1370e+00, -6.3986e-01, -1.5794e+00,
          4.2680e+00,  3.2415e+00, -4.6646e+00,  6.0562e-01, -2.4862e+00,
          5.2591e+00, -7.1519e+00,  4.7275e+00,  7.2529e+00, -2.3203e+00,
         -2.2537e+00,  5.4380e+00, -2.2843e-12,  1.9685e+00,  1.1793e+00,
          6.1543e+00,  6.3497e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2776e-01,  1.1049e+00,  3.1558e-07,  9.7525e-01,  9.7316e-03,
          8.4199e-07,  2.4226e-01,  2.1733e+00,  4.5672e-01,  2.9286e+00,
          3.2361e+00,  2.6080e+00,  1.7628e-01,  4.3175e-04,  9.9388e-18,
          1.3441e-11,  1.5710e-09,  2.7612e-07,  3.6695e-04,  2.1221e-07,
          1.8825e-10,  1.8410e-02,  1.8500e-13,  4.3553e+00,  1.0965e-01,
          7.5854e-03,  4.3797e-02,  2.9264e-02,  8.3780e-07, -2.6109e-19,
          3.6906e-01,  1.0553e+00]])
--------------------------
9-feature dims: torch.Size([1, 512])
Post Training Quantization: Calibration done
C:\Users\User\Anaconda3\Lib\site-packages\torch\quantization\observer.py:845: UserWarning: must run observer before calling calculate_qparams.
     Returning default scale and zero point
  Returning default scale and zero point "
Post Training Quantization: Convert done

 Inverted Residual Block: After fusion and quantization, note fused modules:

 QuantizedConv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.011990774422883987, zero_point=80)
Size of model after quantization
Size (MB): 24.397458

Which version of PyTorch are you currently using? We recently fixed a bug in the histogram observer that should be available in 1.6. You can also use nightlies to see if it fixes your issue.

1 Like

Hi, I’m using 1.5.0!
Ok, I’ll give that a try and report back

Updated to the latest nightly (1.7.0.dev20200714+cpu and torchvision-0.8.0.dev20200714+cpu) just now. It got a bit further, but ultimately crashed with the same error:

Size (MB): 89.322487
QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
Post Training Quantization Prepare: Inserting Observers

 Inverted Residual Block:After observer insertion

 Conv2d(
  3, 64, kernel_size=(3, 3), stride=(1, 1)
  (activation_post_process): HistogramObserver()
)
<inside se forward:>
X: tensor([[-1.5691, -0.7516, -0.7360, -0.6458]])
--------------------------
<inside se forward:>
X: tensor([[ 3.6604e-01,  3.3855e+00, -5.0032e-19, -9.0280e-19]])
--------------------------
<inside se forward:>
X: tensor([[-1.0513, -0.0656, -0.4529,  0.0653, -0.4762, -0.6304, -1.5043, -0.9484]])
--------------------------
<inside se forward:>
X: tensor([[ 4.8730,  1.6650, -0.5135, -0.6811, -0.0392, -0.4689, -0.1496,  0.0717]])
--------------------------
<inside se forward:>
X: tensor([[-1.8759, -0.8886, -1.3295, -0.5375,  0.7598, -0.8526, -1.9066,  0.0985,
         -0.1461, -0.5857,  0.1513, -0.3050,  0.1955, -0.8470,  0.4528,  0.9689]])
--------------------------
<inside se forward:>
X: tensor([[ 1.6184e+00, -2.2714e-18,  2.8052e+00,  1.0378e+01,  4.6361e-05,
          1.0644e+01,  1.4302e-02,  2.6143e-02,  2.4926e-05,  6.2237e+00,
          8.8411e-05,  6.4360e+00,  3.3530e+00,  3.9302e-05,  8.1652e+00,
          8.7950e-07]])
--------------------------
<inside se forward:>
X: tensor([[ 9.1687e+00,  3.1469e+00, -1.1788e+01,  4.9410e-02,  1.7272e+00,
         -3.0913e+00,  1.1572e+00, -6.7104e+00,  1.1371e+01,  4.8926e+00,
         -1.3102e+00, -4.9773e+00, -4.1444e+00, -6.3367e-01, -1.5672e+00,
          4.2629e+00,  3.2491e+00, -4.6632e+00,  5.9241e-01, -2.4883e+00,
          5.2599e+00, -7.1710e+00,  4.7197e+00,  7.2724e+00, -2.3363e+00,
         -2.2564e+00,  5.4431e+00, -2.2832e-12,  1.9732e+00,  1.1682e+00,
          6.1555e+00,  6.3574e+00]])
--------------------------
<inside se forward:>
X: tensor([[ 1.2785e-01,  1.1057e+00,  3.1581e-07,  9.7595e-01,  9.7386e-03,
          8.4260e-07,  2.4243e-01,  2.1749e+00,  4.5704e-01,  2.9307e+00,
          3.2384e+00,  2.6099e+00,  1.7640e-01,  4.3206e-04,  9.9380e-18,
          1.3450e-11,  1.5721e-09,  2.7632e-07,  3.6721e-04,  2.1237e-07,
          1.8839e-10,  1.8423e-02,  1.8514e-13,  4.3584e+00,  1.0972e-01,
          7.5909e-03,  4.3828e-02,  2.9285e-02,  8.3840e-07, -2.6420e-19,
          3.6933e-01,  1.0561e+00]])
--------------------------
0-feature dims: torch.Size([1, 512])
<inside se forward:>
X: tensor([[-1.5517, -0.8007, -0.7286, -0.6478]])
--------------------------
<inside se forward:>
X: tensor([[ 5.0945e-01,  3.2514e+00, -5.2950e-19, -9.1256e-19]])
Traceback (most recent call last):
  File "d:\Codes\org\python\Quantization\quantizer.py", line 266, in <module>
    quantize_test()
  File "d:\Codes\org\python\Quantization\quantizer.py", line 248, in quantize_test
    evaluate(model, dtloader, neval_batches=num_calibration_batches)
  File "d:\Codes\org\python\Quantization\quantizer.py", line 152, in evaluate
    features = model(image.unsqueeze(0))
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 576, in forward
    x = self.layer1(x)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\container.py", line 117, in forward
    input = module(input)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 489, in forward
    out = self.se(out)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 447, in forward
    y = self.prelu_q(y)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "d:\codes\org\python\FV\quantized_models.py", line 322, in forward
    inputs = self.quantized_op.add(torch.relu(inputs), weight_min_res)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\quantized\modules\functional_modules.py", line 46, in add
    r = self.activation_post_process(r)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\nn\modules\module.py", line 726, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\quantization\observer.py", line 862, in forward
    self.bins)
  File "C:\Users\User\Anaconda3\Lib\site-packages\torch\quantization\observer.py", line 813, in _combine_histograms
    histogram_with_output_range = torch.zeros((Nbins * downsample_rate), device=orig_hist.device)
RuntimeError: Trying to create tensor with negative dimension -4398046511104: [-4398046511104]

@supriyar Any idea what the problem is here?
Your help is greatly appreciated.

1 Like

The initial error was due to the histogram observer receiving a tensor whose values were all identical or all zero. Since that was fixed, I am not sure what causes this error.
Could you provide a small repro for us to take a look? Along with the input tensor data for which this error shows up. Thanks!

It seems updating to 1.7 solved this issue! However, the "Unimplemented type" and native batch-norm related issues are still present.
I created a minimal, self-contained example with ResNet18 and a simple two-layer network that covers everything from quantizing the model to testing it on fake data.
By setting the two variables at the top (use_relu, disable_single_bns) you can see different behaviors (most of the code is boilerplate and the ResNet18 definition).
You are free to test this with either the ResNet18 or the SimpleNetwork:

import os
from os.path import abspath, dirname, join
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import FakeData
import torchvision.transforms as transforms
from torch.quantization import fuse_modules
# When True, the networks below use plain torch.relu in place of PReLU_Quantized.
use_relu = False
# When True, the stand-alone BatchNorm layers (bn0/bn2/bn3/bn) are replaced by nn.Identity.
disable_single_bns = False

class PReLU_Quantized(nn.Module):
    """PReLU expressed through primitives that quantization can observe.

    Computes ``relu(x) + (-weight) * relu(-x)``, which equals
    ``max(0, x) + weight * min(0, x)``, routing the multiply and the add
    through a ``FloatFunctional``.

    Args:
        prelu_object: object whose ``weight`` attribute supplies the negative
            slope (an ``nn.PReLU`` at every call site in this file).
    """

    def __init__(self, prelu_object):
        super().__init__()
        self.prelu_weight = prelu_object.weight
        # Same parameter registered under a second name; left in place so the
        # module's state_dict keys do not change.
        self.weight = self.prelu_weight
        self.quantized_op = nn.quantized.FloatFunctional()
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, inputs):
        """Return the PReLU of ``inputs``, passed through the dequant stub."""
        # Keep the stub output in a local. ``weight`` is a registered
        # parameter, so assigning the stub output back to ``self.weight``
        # raises TypeError as soon as the stub returns anything that is not an
        # ``nn.Parameter``, and it would also mutate module state per call.
        weight = self.quant(self.weight)
        negative_part = self.quantized_op.mul(-weight, torch.relu(-inputs))
        result = self.quantized_op.add(torch.relu(inputs), negative_part)
        return self.dequant(result)

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and an add+ReLU skip connection."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # First stage: conv -> bn -> relu; only this conv applies ``stride``.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Second stage: conv -> bn.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        # Functional wrapper for the skip connection's add followed by ReLU.
        self.add_relu = torch.nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return ``add_relu(branch(x), shortcut)``; the shortcut is ``x`` or ``downsample(x)``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.add_relu.add_relu(branch, shortcut)

    def fuse_model(self):
        """Fuse conv1+bn1+relu and conv2+bn2 in place, plus the downsample pair when present."""
        groups = [['conv1', 'bn1', 'relu'], ['conv2', 'bn2']]
        torch.quantization.fuse_modules(self, groups, inplace=True)
        if self.downsample:
            torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)

class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block whose output has ``planes * 4`` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        out_planes = planes * 4
        # 1x1 reduction.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # 3x3 convolution; the only one that applies ``stride``.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 expansion.
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        # Separate ReLU modules so each can be fused with its own conv/bn pair.
        self.relu1 = nn.ReLU(inplace=False)
        self.relu2 = nn.ReLU(inplace=False)
        self.downsample = downsample
        self.stride = stride
        self.skip_add_relu = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Return ``add_relu(branch(x), shortcut)``; the shortcut is ``x`` or ``downsample(x)``."""
        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.relu2(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.skip_add_relu.add_relu(branch, shortcut)

    def fuse_model(self):
        """Fuse the three conv/bn(/relu) groups in place, plus the downsample pair when present."""
        groups = [
            ['conv1', 'bn1', 'relu1'],
            ['conv2', 'bn2', 'relu2'],
            ['conv3', 'bn3'],
        ]
        fuse_modules(self, groups, inplace=True)
        if self.downsample:
            fuse_modules(self.downsample, ['0', '1'], inplace=True)

class SEBlock(nn.Module):
    """Channel gate: global average pool -> fc1 -> activation -> fc2 -> sigmoid, then rescale ``x``."""

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Functional wrapper for the final element-wise multiply.
        self.mult_xy = nn.quantized.FloatFunctional()
        self.fc = nn.Sequential(
            nn.Linear(channel, squeezed),
            nn.PReLU(),
            nn.Linear(squeezed, channel),
            nn.Sigmoid(),
        )
        # Expose the Sequential's members by name so forward() can call them
        # one at a time and substitute the activation.
        self.fc1, self.prelu, self.fc2, self.sigmoid = self.fc
        self.prelu_q = PReLU_Quantized(self.prelu)
        self.prelu_q_or_relu = torch.relu if use_relu else self.prelu_q

    def forward(self, x):
        """Scale each channel of the 4-D input ``x`` by its sigmoid gate."""
        batch, channels, _, _ = x.size()
        gate = self.avg_pool(x).view(batch, channels)
        gate = self.fc2(self.prelu_q_or_relu(self.fc1(gate)))
        gate = self.sigmoid(gate).view(batch, channels, 1, 1)
        # Equivalent to ``x * gate``, routed through the functional wrapper.
        return self.mult_xy.mul(x, gate)

class IRBlock(nn.Module):
    """Residual block: bn0 -> conv1/bn1 -> activation -> conv2/bn2 -> optional SE -> activation -> add."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
        super().__init__()
        self.bn0 = nn.BatchNorm2d(inplanes)
        # bn0 is not part of any fused group; the global flag swaps it for a no-op.
        self.bn0_or_identity = torch.nn.Identity() if disable_single_bns else self.bn0

        self.conv1 = conv3x3(inplanes, inplanes)
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        self.prelu_q_or_relu = torch.relu if use_relu else self.prelu_q

        self.conv2 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.use_se = use_se
        # The SE block is always constructed; ``use_se`` only gates its use in forward().
        self.se = SEBlock(planes)
        self.add_residual = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Run the block on ``x`` and add the (optionally downsampled) shortcut."""
        activation = self.prelu_q_or_relu

        # TODO: bn0 needs to be quantized as well.
        out = self.bn0_or_identity(x)
        out = activation(self.bn1(self.conv1(out)))
        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se(out)

        shortcut = x if self.downsample is None else self.downsample(x)

        # The activation runs before the residual add here. We may need to
        # change prelu into relu and use add_relu instead of add.
        out = activation(out)
        return self.add_residual.add(out, shortcut)

    def fuse_model(self):
        """Fuse conv1+bn1 and conv2+bn2 in place (bn0 is left alone), plus the downsample pair."""
        groups = [['conv1', 'bn1'], ['conv2', 'bn2']]
        fuse_modules(self, groups, inplace=True)
        if self.downsample:
            fuse_modules(self.downsample, ['0', '1'], inplace=True)

class ResNet(nn.Module):
    """ResNet-style embedding network wrapped in quant/dequant stubs.

    Args:
        block: residual block class. It must expose ``expansion`` and accept
            ``(inplanes, planes, stride, downsample, use_se=...)``.
        layers: number of blocks in each of the four stages.
        use_se: forwarded to every block as the ``use_se`` keyword.
    """

    def __init__(self, block, layers, use_se=True):
        # Initialise nn.Module before setting any attribute on the instance.
        super().__init__()
        self.inplanes = 64
        self.use_se = use_se
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.prelu = nn.PReLU()
        self.prelu_q = PReLU_Quantized(self.prelu)
        # This is only to get rid of the unimplemented CPUQuantization type
        # error when PReLU_Quantized is used at test time.
        if use_relu:
            self.prelu_q_or_relu = torch.relu
        else:
            self.prelu_q_or_relu = self.prelu_q

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.bn2 = nn.BatchNorm2d(512)
        # Works around single BatchNorms not getting fused, which at test time causes
        # RuntimeError: Could not run 'aten::native_batch_norm' with arguments
        # from the 'QuantizedCPU' backend.
        if disable_single_bns:
            self.bn2_or_identity = torch.nn.Identity()
        else:
            self.bn2_or_identity = self.bn2

        self.dropout = nn.Dropout()
        # Assumes layer4 yields a 512x7x7 feature map (the case for the
        # 112x112 inputs used by this script).
        self.fc = nn.Linear(512 * 7 * 7, 512)
        self.bn3 = nn.BatchNorm1d(512)
        if disable_single_bns:
            self.bn3_or_identity = torch.nn.Identity()
        else:
            self.bn3_or_identity = self.bn3
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        # Weight initialisation: Xavier for conv/linear weights, unit scale and
        # zero shift for batch norms, zero bias for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the shape changes, a
        1x1 conv + batch norm downsample for the shortcut.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample, use_se=self.use_se)]
        # The following blocks consume the first block's output, which has
        # ``planes * block.expansion`` channels (the same quantity the
        # downsample above is built for).
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, use_se=self.use_se))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``x`` to 512-dimensional features."""
        x = self.quant(x)
        x = self.conv1(x)
        # TODO: single bn needs to be fused
        x = self.bn1(x)
        x = self.prelu_q_or_relu(x)

        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2_or_identity(x)
        x = self.dropout(x)
        # reshape (rather than view) to flatten everything but the batch dim.
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        # TODO: single bn needs to be fused
        x = self.bn3_or_identity(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        r"""Fuse conv/bn/relu modules to prepare the model for quantization.

        Fuses the stem ``conv1`` + ``bn1`` and then calls ``fuse_model()`` on
        every residual block. The model is modified in place and remains a
        floating-point model afterwards.
        """
        fuse_modules(self, ['conv1', 'bn1'], inplace=True)
        for m in self.modules():
            if isinstance(m, (Bottleneck, BasicBlock, IRBlock)):
                m.fuse_model()

class SimpleNetwork(nn.Module):
    """Tiny test network: conv/bn/relu, then the PReLU-or-ReLU activation and a lone batch norm."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(10)
        self.relu1 = nn.ReLU()

        self.prelu_q = PReLU_Quantized(nn.PReLU())
        self.bn = nn.BatchNorm2d(10)

        # The global flags select which activation and which trailing norm
        # forward() actually uses.
        if use_relu:
            self.prelu_q_or_relu = torch.relu
        else:
            self.prelu_q_or_relu = self.prelu_q
        if disable_single_bns:
            self.bn_or_identity = nn.Identity()
        else:
            self.bn_or_identity = self.bn

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Run ``x`` through the network between the quant and dequant stubs."""
        out = self.quant(x)
        out = self.relu1(self.bn1(self.conv1(out)))
        out = self.bn_or_identity(self.prelu_q_or_relu(out))
        return self.dequant(out)

def resnet18(use_se=True, **kwargs):
    """Build a ``ResNet`` of ``IRBlock``s with two blocks in each of its four stages."""
    blocks_per_stage = [2] * 4
    return ResNet(IRBlock, blocks_per_stage, use_se=use_se, **kwargs)

def print_size_of_model(model):
    """Print the size, in MB (1e6 bytes), of ``model``'s serialized state_dict.

    The state_dict is written to a scratch file inside a private temporary
    directory, so an unrelated ``temp.p`` in the working directory is never
    overwritten or deleted, and the scratch file is cleaned up even when
    saving or measuring raises.
    """
    import tempfile  # local import: only this function needs it

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, "temp.p")
        torch.save(model.state_dict(), scratch_path)
        print('Size (MB):', os.path.getsize(scratch_path) / 1e6)

def evaluate(model, data_loader, eval_batches):
    """Run ``model`` in eval mode, without gradients, on at most ``eval_batches`` batches.

    Args:
        model: module to call on each batch's image tensor.
        data_loader: iterable yielding ``(image, target)`` pairs.
        eval_batches: maximum number of batches to process.

    Prints the shape of the model output for every processed batch.
    """
    model.eval()
    with torch.no_grad():
        for i, (image, _target) in enumerate(data_loader):
            # Check the budget before the forward pass; checking it afterwards
            # let one extra batch through.
            if i >= eval_batches:
                return
            features = model(image)
            print(f'{i})feature dims: {features.shape}')

def load_quantized(model, quantized_checkpoint_file_path):
    """Convert ``model`` to its quantized form and load saved quantized weights into it.

    The model is fused (ResNet only), prepared and converted without any
    calibration, so that its module structure matches the checkpoint; the
    checkpoint's state_dict is then loaded on CPU.

    Args:
        model: float model to convert; modified in place.
        quantized_checkpoint_file_path: path of the saved quantized state_dict.

    Returns:
        The same ``model`` object, now quantized and loaded.
    """
    model.eval()
    # isinstance (not an exact type comparison) so ResNet subclasses are fused too.
    if isinstance(model, ResNet):
        model.fuse_model()
    # Specify quantization configuration
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    checkpoint = torch.load(quantized_checkpoint_file_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint)
    print_size_of_model(model)
    return model

def test_the_model(model, dtloader):
    """Load the saved quantized weights into ``model`` and run the first batch of ``dtloader`` through it."""
    checkpoint_path = join(abspath(dirname(__file__)), 'data', 'model_quantized_jit.pth')
    model = load_quantized(model, checkpoint_path)
    model.eval()
    first_batch = next(iter(dtloader))
    img, _ = first_batch
    # The output is not inspected; the call only checks that inference runs.
    model(img)

def quantize_model(model, dtloader):
    """Apply post-training static quantization to ``model`` and save its state_dict.

    Steps: fuse (ResNet only), attach the default 'fbgemm' qconfig, insert
    observers, calibrate on ``dtloader``, convert, then write the quantized
    state_dict to ``data/model_quantized_jit.pth`` next to this file.

    Args:
        model: float model to quantize; modified in place.
        dtloader: iterable of ``(image, target)`` batches used for calibration.
    """
    calibration_batches = 10
    saved_model_dir = 'data'
    scripted_quantized_model_file = 'model_quantized_jit.pth'

    model.eval()
    # isinstance (not an exact type comparison) so ResNet subclasses are fused too.
    if isinstance(model, ResNet):
        model.fuse_model()
    print_size_of_model(model)
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    print(model.qconfig)
    torch.quantization.prepare(model, inplace=True)

    print(f'Model after fusion(prepared): {model}')

    # Calibrate first
    print('Post Training Quantization Prepare: Inserting Observers')
    print('\n Inverted Residual Block:After observer insertion \n\n', model.conv1)

    # Calibrate with the training set
    evaluate(model, dtloader, eval_batches=calibration_batches)
    print('Post Training Quantization: Calibration done')

    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)
    print('Post Training Quantization: Convert done')
    print('\n Inverted Residual Block: After fusion and quantization, note fused modules: \n\n', model.conv1)

    print("Size of model after quantization")
    print_size_of_model(model)

    # Only the state_dict is saved, so the model is no longer scripted with
    # torch.jit.script here; the scripted module was never used.
    save_dir = join(dirname(abspath(__file__)), saved_model_dir)
    # The output directory may not exist yet on a first run.
    os.makedirs(save_dir, exist_ok=True)
    path_tosave = join(save_dir, scripted_quantized_model_file)
    print(f'path to save: {path_tosave}')
    with open(path_tosave, 'wb') as f:
        torch.save(model.state_dict(), f)

    print(f'model after quantization (prepared and converted:) {model}')

# Synthetic data: 1000 fake 3x112x112 images in 5 classes, served one image per batch.
dataset = FakeData(1000, image_size=(3, 112, 112), num_classes=5, transform=transforms.ToTensor())
data_loader = DataLoader(dataset, batch_size=1)

# Step 1: quantize a freshly built model and save its state_dict.
model = resnet18()
# Alternative network to test with:
# model = SimpleNetwork()
quantize_model(model, data_loader)

# Step 2: build a new float model, load the quantized weights into it, and run one batch.
model = resnet18()
# Alternative network to test with:
# model = SimpleNetwork()
test_the_model(model, data_loader)