Graph mode static quantization fails to quantize like eager mode static quantization

Hello everyone, hope you are having a great time.
I tried both quantization approaches and noticed that graph mode post-training static quantization does not work properly: manual (eager mode) static quantization gives nearly a 5x model size reduction and nearly a 3x runtime speedup, while graph mode changes almost nothing.
I'm using PyTorch 1.6, and this is how I'm doing it:

import os
import torch

from torch.quantization import get_default_qconfig
from torch.quantization import quantize_jit

from dataset import ArcDataset
from utils import Benchmark_Block

def calibrate(model, data_loader):
    # Run the model over the calibration set so the observers can record
    # activation ranges, which are used to compute quantization parameters.
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

def quantize_post_training(jit_model_path, path_to_save, data_loader_test,
                           dummy_input=torch.randn(size=(1, 3, 112, 112))):
    # Load the TorchScript model and apply graph mode post-training
    # static quantization.
    jit_model = torch.jit.load(jit_model_path)
    qconfig = get_default_qconfig('fbgemm')
    # The empty key applies this qconfig to the whole model.
    qconfig_dict = {'': qconfig}

    quantized_model = quantize_jit(jit_model,
                                   qconfig_dict,
                                   calibrate,
                                   [data_loader_test],
                                   inplace=False,
                                   debug=False)

    torch.jit.save(quantized_model, path_to_save)
    print('Quantization is done!')

    # Compare latency over 100 forward passes.
    with Benchmark_Block("Default Model: "):
        for i in range(100):
            _ = jit_model(dummy_input)

    with Benchmark_Block("Quantized Model: "):
        for i in range(100):
            _ = quantized_model(dummy_input)

    # Compare on-disk sizes.
    print(f'default model size: {os.path.getsize(jit_model_path)/1e6} MB')
    print(f'quantized model size: {os.path.getsize(path_to_save)/1e6} MB')

def run_quantization():
    train_dataset = ArcDataset(sample_count=1000)
    dtloader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4)
    model_path = "checkpoint_test.jit"
    model_save_path = "checkpoint_test_q.jit"
    quantize_post_training(model_path, model_save_path, dtloader)

run_quantization()
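
(ArcDataset and Benchmark_Block come from my local dataset and utils modules. For anyone reproducing this, Benchmark_Block is just a timing helper; a minimal sketch with the same interface, roughly equivalent to what utils provides:)

import time

class Benchmark_Block:
    # Minimal sketch of the timing helper: measures the wall-clock time
    # of the `with` body and prints it in milliseconds.
    def __init__(self, label):
        self.label = label

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        elapsed_ms = (time.perf_counter() - self.start) * 1000.0
        print(f'{self.label} took {elapsed_ms:.3f} ms '
              f'[min/max: {elapsed_ms:.1f}/{elapsed_ms:.1f}] ms')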

This finishes, and the results are as follows:

torch version: 1.6.0
/root/anaconda3/envs/ShishoSama/lib/python3.7/site-packages/torch/nn/modules/module.py:385: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
  if param.grad is not None:
Quantization is done!
Default Model:  took 3216.330 ms [min/max: 3216.3/3216.3] ms
Quantized Model:  took 3799.413 ms [min/max: 3799.4/3799.4] ms
default model size: 22.936544 MB
quantized model size: 22.119849 MB

As you can see, nearly nothing changed between the two models.
Running static quantization in eager mode, however, gives these numbers:

Default Model:  took 2975.428 ms [min/max: 2975.4/2975.4] ms
Size (MB): 22.853838
Quantized Model:  took 373.182 ms [min/max: 373.2/373.2] ms
Size (MB): 5.798671
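
(For reference, the eager mode numbers come from the standard fuse → prepare → calibrate → convert recipe. A minimal sketch of that flow; TinyNet here is a hypothetical stand-in, since the real model isn't shown in this post:)

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # Hypothetical stand-in for the real model (not shown in this post).
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> int8 boundary
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> float boundary

    def forward(self, x):
        return self.dequant(self.relu(self.bn(self.conv(self.quant(x)))))

model = TinyNet().eval()
# Fuse Conv+BN+ReLU so they are quantized as a single module.
torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']], inplace=True)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)   # insert observers
with torch.no_grad():
    model(torch.randn(1, 3, 112, 112))            # calibration (use real data)
torch.quantization.convert(model, inplace=True)   # swap in quantized modules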

Based on the information provided here, we should be able to achieve the same, or very close to the same, results as in eager mode, so I wonder what I'm doing wrong here.
Any help is greatly appreciated.
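
(In case it helps anyone debugging the same symptom: one way to check whether graph mode actually inserted quantized ops is to print the scripted graph. Quantized convolutions show up as quantized::conv2d nodes; if everything is still aten::conv2d, nothing was actually quantized.)

quantized_model = torch.jit.load("checkpoint_test_q.jit")
# Look for quantized::conv2d / quantized::linear in the printed graph.
print(quantized_model.graph)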

Extra notes:
The base model was trained with PyTorch 1.5.1, and the quantization (both graph mode and eager mode) is being done with PyTorch 1.6.

OK, it seems this was a bug in 1.6 only, as upgrading to 1.7 fixed the issue!
Here are the results in 1.7:
Graph mode static quantization:

Quantization is done!
Default Model:  took 895.521 ms [min/max: 895.5/895.5] ms
Quantized Model:  took 337.002 ms [min/max: 337.0/337.0] ms
default model size: 22.936544 MB
quantized model size: 5.780453 MB

Eager mode static quantization:

Default Model:  took 1211.988 ms [min/max: 1212.0/1212.0] ms
Size (MB): 22.853838
Quantized Model:  took 336.531 ms [min/max: 336.5/336.5] ms
Size (MB): 5.798671

Although the baseline timing for the eager mode run is a bit off compared to graph mode, I guess this is normal.