Quantization.convert after QAT pickling issue

Hello

I’ve been having an issue with torch.quantization.convert after performing QAT:

I modified the model (a face detector) to do QAT by adding the lines

    net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(net, inplace=True)

in train.py, and a QuantStub/DeQuantStub pair in the forward() of Mb_Tiny_RFB() (vision/nn/mb_tiny_rfb.py), following this tutorial. I then saved the model via torch.save(self.state_dict(), path).
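
Roughly, the stub wiring looks like this (a simplified sketch with a placeholder backbone, not the actual Mb_Tiny_RFB layers):

    import torch
    from torch import nn

    class Mb_Tiny_RFB(nn.Module):
        """Simplified sketch of the QuantStub/DeQuantStub wiring, not the real backbone."""
        def __init__(self):
            super().__init__()
            self.quant = torch.quantization.QuantStub()      # marks where inputs get (fake-)quantized
            self.backbone = nn.Sequential(                   # placeholder for the real layers
                nn.Conv2d(3, 16, 3, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(16),
                nn.ReLU(),
            )
            self.dequant = torch.quantization.DeQuantStub()  # marks where outputs return to float

        def forward(self, x):
            x = self.quant(x)
            x = self.backbone(x)
            x = self.dequant(x)
            return x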

train.py
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")

if __name__ == '__main__':
    timer = Timer()
    create_net = create_Mb_Tiny_RFB_fd

    train_transform = TrainAugmentation(config.image_size, config.image_mean, config.image_std)
    target_transform = MatchPrior(config.priors, config.center_variance,
                                  config.size_variance, args.overlap_threshold)

    test_transform = TestTransform(config.image_size, config.image_mean_test, config.image_std)

    datasets = []
    for dataset_path in args.datasets:
        if args.dataset_type == 'voc':
            dataset = VOCDataset(dataset_path, transform=train_transform,
                                 target_transform=target_transform, img_size = config.image_size)
            label_file = os.path.join(args.checkpoint_folder, "voc-model-labels.txt")
            store_labels(label_file, dataset.class_names)
            num_classes = len(dataset.class_names)
        else:
            raise ValueError(f"Dataset type {args.dataset_type} is not supported.")
        datasets.append(dataset)
    train_dataset = ConcatDataset(datasets)

    train_loader = DataLoader(train_dataset, args.batch_size,
                              num_workers=args.num_workers,
                              shuffle=True, pin_memory=True)
    val_dataset = VOCDataset(args.validation_dataset, transform=test_transform,
                                 target_transform=target_transform, is_test=True)
    val_loader = DataLoader(val_dataset, args.batch_size,
                            num_workers=args.num_workers,
                            shuffle=False)

    net = create_net(num_classes)

    min_loss = -10000.0
    last_epoch = -1

    base_net_lr = args.base_net_lr if args.base_net_lr is not None else args.lr
    extra_layers_lr = args.extra_layers_lr if args.extra_layers_lr is not None else args.lr
 
    params = [
            {'params': net.base_net.parameters(), 'lr': base_net_lr},
            {'params': itertools.chain(
                net.source_layer_add_ons.parameters(),
                net.extras.parameters()
            ), 'lr': extra_layers_lr},
            {'params': itertools.chain(
                net.regression_headers.parameters(),
                net.classification_headers.parameters()
            )}
        ]

    if args.resume:
        logging.info(f"Resume from the model {args.resume}")
        net.load(args.resume)

    criterion = MultiboxLoss(config.priors, neg_pos_ratio=3,
                             center_variance=0.1, size_variance=0.2, device=DEVICE)
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    ...
    net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(net, inplace=True)

    net.to(DEVICE)

    for epoch in range(last_epoch + 1, args.num_epochs):
        train(train_loader, net, criterion, optimizer,
              device=DEVICE, debug_steps=args.debug_steps, epoch=epoch)
        if epoch > 3:
            # Freeze quantizer parameters
            net.apply(torch.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates towards the end of training to better match inference numerics.
            net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        if epoch % args.validation_epochs == 0 or epoch == args.num_epochs - 1:
            logging.info("lr rate :{}".format(optimizer.param_groups[0]['lr']))
            val_loss, val_regression_loss, val_classification_loss = test(val_loader, net, criterion, DEVICE)

            net.eval()
            quant_model = torch.quantization.convert(net.cpu(), inplace=False) # <-- error happens here

            model_path = os.path.join(args.checkpoint_folder, f"{args.net}-Epoch-{epoch}-Loss-{val_loss}.pth")
            net.save(model_path)

When I tried to call quantization.convert() before saving, I got the error:

Traceback (most recent call last):
  File "train.py", line 432, in <module>
    quant_model = torch.quantization.convert(net.module.eval().cpu(), inplace=False)
  File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/site-packages/torch/quantization/quantize.py", line 299, in convert
    module = copy.deepcopy(module)
  File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 270, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 146, in deepcopy
    y = copier(x, memo)
  File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 230, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle 'module' object
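
(As far as I can tell, the deepcopy only happens because of inplace=False; converting in place skips the copy, though it obviously doesn't explain why the copy fails in the first place:)

    # Possible workaround, not a fix: convert in place so torch.quantization.convert
    # does not deepcopy the model first. The underlying un-deepcopy-able attribute
    # is still there.
    net.eval()
    quant_model = torch.quantization.convert(net.cpu(), inplace=True)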

So instead I tried to load the QAT'd parameters into a simpler, flattened version of the model and convert that, but got the same error:

import torch
import torchvision
from torch import nn
from vision.utils import box_utils
from vision.ssd.config.fd_config import define_img_size
define_img_size(640)
from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd
import torch.nn.functional as F
import cv2
import numpy as np

class_names = ['background', 'face']
net_1 = create_Mb_Tiny_RFB_fd(len(class_names), is_test=True, device='cpu')
net_1.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(net_1, inplace=True)

# load definition: self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
net_1.load(model_path) # load the previously QAT'd model (without quantisation conversion)

class SimpleNet(nn.Module):
    def __init__(self, base_net, regression_headers, classification_headers, extras, priors, config):
        super(SimpleNet, self).__init__()
        self.backbone0 = base_net[:8]
        self.backbone1 = base_net[8:11]
        self.backbone2 = base_net[11:13]
        self.last_chunk = base_net[13:] 
        self.regression_headers0 = regression_headers[0]
        self.regression_headers1 = regression_headers[1]
        self.regression_headers2 = regression_headers[2]
        self.regression_headers3 = regression_headers[3]
        self.classification_headers0 = classification_headers[0]
        self.classification_headers1 = classification_headers[1]
        self.classification_headers2 = classification_headers[2]
        self.classification_headers3 = classification_headers[3]
        self.extras = extras
        self.num_classes = 2
        self.priors = priors
        self.config = config
        self.last_op = nn.Softmax(dim=-1)

    def forward(self, x):
        confidences = []
        locations = []
        x = self.backbone0(x)
        confidence = self.classification_headers0(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers0(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        x = self.backbone1(x)
        confidence = self.classification_headers1(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers1(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        x = self.backbone2(x)
        confidence = self.classification_headers2(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers2(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        x = self.last_chunk(x)

        x = self.extras(x)
        confidence = self.classification_headers3(x)
        confidence = confidence.permute(0, 2, 3, 1).contiguous()
        confidence = confidence.view(confidence.size(0), -1, self.num_classes)
        location = self.regression_headers3(x)
        location = location.permute(0, 2, 3, 1).contiguous()
        location = location.view(location.size(0), -1, 4)
        confidences.append(confidence)
        locations.append(location)

        confidences = torch.cat(confidences, 1)
        confidences = self.last_op(confidences)
        locations = torch.cat(locations, 1)

        boxes = box_utils.convert_locations_to_boxes(
            locations, self.priors, torch.tensor([0.1]), torch.tensor([0.2]) #self.config.center_variance, self.config.size_variance
        )
        boxes = box_utils.center_form_to_corner_form(boxes)
        return confidences, boxes


model = SimpleNet(
        net_1.base_net,
        net_1.regression_headers,
        net_1.classification_headers,
        net_1.extras[0],
        net_1.priors,
        net_1.config)

model.eval()
model = torch.quantization.convert(model, inplace=False) # error here

Here's the `print(model)` output:
SimpleNet(
  (backbone0): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (4): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (5): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (6): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (7): BasicRFB(
      (branch0): Sequential(
        (0): BasicConv(
          (conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        )
        (1): BasicConv(
          (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
          (activation_fn): ReLU()
        )
        (2): BasicConv(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
      (branch1): Sequential(
        (0): BasicConv(
          (conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        )
        (1): BasicConv(
          (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
          (activation_fn): ReLU()
        )
        (2): BasicConv(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
          (bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
      (branch2): Sequential(
        (0): BasicConv(
          (conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn): BatchNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        )
        (1): BasicConv(
          (conv): Conv2d(8, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(12, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
          (activation_fn): ReLU()
        )
        (2): BasicConv(
          (conv): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
          (activation_fn): ReLU()
        )
        (3): BasicConv(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
          (bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
      (ConvLinear): BasicConv(
        (conv): Conv2d(48, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      )
      (shortcut): BasicConv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
      )
      (activation_fn): ReLU()
    )
  )
  (backbone1): Sequential(
    (8): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (9): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (10): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (backbone2): Sequential(
    (11): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (12): Sequential(
      (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (last_chunk): Sequential()
  (regression_headers0): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
    (1): ReLU()
    (2): Conv2d(64, 12, kernel_size=(1, 1), stride=(1, 1))
  )
  (regression_headers1): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
    (1): ReLU()
    (2): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))
  )
  (regression_headers2): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
    (1): ReLU()
    (2): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1))
  )
  (regression_headers3): Conv2d(256, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (classification_headers0): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
    (1): ReLU()
    (2): Conv2d(64, 6, kernel_size=(1, 1), stride=(1, 1))
  )
  (classification_headers1): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
    (1): ReLU()
    (2): Conv2d(128, 4, kernel_size=(1, 1), stride=(1, 1))
  )
  (classification_headers2): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
    (1): ReLU()
    (2): Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))
  )
  (classification_headers3): Conv2d(256, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (extras): Sequential(
    (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
      (1): ReLU()
      (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
    )
    (3): ReLU()
  )
  (last_op): Softmax(dim=-1)
)

It appeared that something in the model could not be pickled, so I tried using dill to trace what:

import dill

dill.detect.trace(True)
dill.detect.errors(model)

and the output is:

T2: <class '__main__.SimpleNet'>
F2: <function _create_type at 0x7f574cb6a670>
# F2
T1: <class 'type'>
F2: <function _load_type at 0x7f574cb6a5e0>
# F2
# T1
T4: <class 'torch.nn.modules.module.Module'>
# T4
D2: <dict object at 0x7f574c9f4b40>
F1: <function SimpleNet.__init__ at 0x7f574cb86f70>
F2: <function _create_function at 0x7f574cb6a700>
# F2
Co: <code object __init__ at 0x7f57af0eb240, file "simple_ul.py", line 38>
F2: <function _create_code at 0x7f574cb6a790>
# F2
# Co
D1: <dict object at 0x7f57af1d7f00>
# D1
Ce: <cell at 0x7f574cb5c790: type object at 0x5637763f3840>
F2: <function _create_cell at 0x7f574cb6ab80>
# F2
T5: <class '__main__.SimpleNet'>
# T5
# Ce
D2: <dict object at 0x7f574c9f48c0>
# D2
# F1
F1: <function SimpleNet.forward at 0x7f574cb86ee0>
Co: <code object forward at 0x7f57af0f9450, file "simple_ul.py", line 60>
# Co
D1: <dict object at 0x7f57af1d7f00>
# D1
D2: <dict object at 0x7f574c9f4a00>
# D2
# F1
# D2
# T2
D2: <dict object at 0x7f574c9fa740>
T4: <class 'collections.OrderedDict'>
# T4
T4: <class 'torch.nn.modules.container.Sequential'>
# T4
D2: <dict object at 0x7f574c9fa4c0>
D2: <dict object at 0x7f574d8c7580>
T4: <class 'torch.nn.qat.modules.conv.Conv2d'>
# T4
D2: <dict object at 0x7f574ca8db40>
F2: <function _rebuild_parameter at 0x7f57ad6594c0>
# F2
F2: <function _rebuild_tensor_v2 at 0x7f57ad659280>
# F2
/home/user/anaconda3/envs/FaceDetector/lib/python3.8/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead
  warnings.warn("pickle support for Storage will be removed in 1.5. Use `torch.save` instead", FutureWarning)
F2: <function _load_from_bytes at 0x7f5759b7b310>
# F2
T4: <class 'torch.quantization.fake_quantize.FakeQuantize'>
# T4
D2: <dict object at 0x7f5759343500>
T4: <class 'torch.quantization.observer.MovingAverageMinMaxObserver'>
# T4
D2: <dict object at 0x7f574d8b7ac0>
# D2
# D2
D2: <dict object at 0x7f574ca8de00>
T4: <class 'torch.quantization.observer.MovingAveragePerChannelMinMaxObserver'>
# T4
D2: <dict object at 0x7f574ca8de40>
# D2
# D2
T6: <class 'torch.quantization.qconfig.QConfig'>
F2: <function _create_namedtuple at 0x7f574cb6f0d0>
# F2
# T6
T4: <class 'torch.quantization.observer._with_args.<locals>._PartialWrapper'>

But I don't really understand the output or how to solve this problem... Am I doing QAT -> quantisation correctly? Is it correct to set the qconfig and call prepare_qat again before loading a QAT'd model? Any help would be greatly appreciated.

Looks like it's failing in copy.deepcopy(module). Just to confirm, does copy.deepcopy work on your model instance before you do QAT?
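
Something like this (rough sketch; the helper name is just for illustration) can help narrow down which submodule is the culprit — the deepest failing entry usually points at the bad attribute:

    import copy

    def find_undeepcopyable(model):
        """Try deepcopy on every submodule and report the ones that fail."""
        for name, module in model.named_modules():
            try:
                copy.deepcopy(module)
            except Exception as e:
                print(f"{name or '<root>'}: {e}")

    find_undeepcopyable(net_1)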

Hello, yeah, I realised that deepcopy did not work on my original model either, and found the issue: I had some unpicklable objects saved in the __init__ of my model.
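
The pattern was roughly like this (a minimal sketch, not the actual attribute from my model): a plain Python module stored on self, which copy.deepcopy cannot handle:

    import copy
    import math               # any plain Python module will do for the repro
    from torch import nn

    class Broken(nn.Module):
        def __init__(self):
            super().__init__()
            self.helper = math  # module object stored as an attribute in __init__

        def forward(self, x):
            return x

    copy.deepcopy(Broken())   # TypeError: cannot pickle 'module' object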

Thank you!

Could you share what was in your original model causing deepcopy to fail? I am having the same problem and need some clues to fix it. Thanks for your help.