Hello
I’ve been having an issue with torch.quantization.convert
after performing QAT -
I modified the model (face detector) to do QAT by adding the lines
net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(net, inplace=True)
in train.py
and the QuantStub/DeStub in the forward() of Mb_Tiny_RFB() (vision/nn/mb_tiny_rfb.py
) following this tutorial, and then saved the model via torch.save(self.state_dict(), path)
train.py
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() and args.use_cuda else "cpu")
if __name__ == '__main__':
timer = Timer()
create_net = create_Mb_Tiny_RFB_fd
train_transform = TrainAugmentation(config.image_size, config.image_mean, config.image_std)
target_transform = MatchPrior(config.priors, config.center_variance,
config.size_variance, args.overlap_threshold)
test_transform = TestTransform(config.image_size, config.image_mean_test, config.image_std)
datasets = []
for dataset_path in args.datasets:
if args.dataset_type == 'voc':
dataset = VOCDataset(dataset_path, transform=train_transform,
target_transform=target_transform, img_size = config.image_size)
label_file = os.path.join(args.checkpoint_folder, "voc-model-labels.txt")
store_labels(label_file, dataset.class_names)
num_classes = len(dataset.class_names)
else:
raise ValueError(f"Dataset type {args.dataset_type} is not supported.")
datasets.append(dataset)
train_dataset = ConcatDataset(datasets)
train_loader = DataLoader(train_dataset, args.batch_size,
num_workers=args.num_workers,
shuffle=True, pin_memory=True)
val_dataset = VOCDataset(args.validation_dataset, transform=test_transform,
target_transform=target_transform, is_test=True)
val_loader = DataLoader(val_dataset, args.batch_size,
num_workers=args.num_workers,
shuffle=False)
net = create_net(num_classes)
min_loss = -10000.0
last_epoch = -1
base_net_lr = args.base_net_lr if args.base_net_lr is not None else args.lr
extra_layers_lr = args.extra_layers_lr if args.extra_layers_lr is not None else args.lr
params = [
{'params': net.base_net.parameters(), 'lr': base_net_lr},
{'params': itertools.chain(
net.source_layer_add_ons.parameters(),
net.extras.parameters()
), 'lr': extra_layers_lr},
{'params': itertools.chain(
net.regression_headers.parameters(),
net.classification_headers.parameters()
)}
]
if args.resume:
logging.info(f"Resume from the model {args.resume}")
net.load(args.resume)
criterion = MultiboxLoss(config.priors, neg_pos_ratio=3,
center_variance=0.1, size_variance=0.2, device=DEVICE)
optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay)
...
net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(net, inplace=True)
net.to(DEVICE)
for epoch in range(last_epoch + 1, args.num_epochs):
train(train_loader, net, criterion, optimizer,
device=DEVICE, debug_steps=args.debug_steps, epoch=epoch)
if epoch > 3:
# Freeze quantizer parameters
net.apply(torch.quantization.disable_observer)
if epoch > 2:
# Freeze batch norm mean and variance estimates towards the end of training to better match inference numerics.
net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
if epoch % args.validation_epochs == 0 or epoch == args.num_epochs - 1:
logging.info("lr rate :{}".format(optimizer.param_groups[0]['lr']))
val_loss, val_regression_loss, val_classification_loss = test(val_loader, net, criterion, DEVICE)
net.eval()
quant_model = torch.quantization.convert(net.cpu(), inplace=False) # <-- error happens here
model_path = os.path.join(args.checkpoint_folder, f"{args.net}-Epoch-{epoch}-Loss-{val_loss}.pth")
net.save(model_path)
When I tried to call quantization.convert()
before saving, I got the error:
Traceback (most recent call last):
File "train.py", line 432, in <module>
quant_model = torch.quantization.convert(net.module.eval().cpu(), inplace=False)
File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/site-packages/torch/quantization/quantize.py", line 299, in convert
module = copy.deepcopy(module)
File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 172, in deepcopy
y = _reconstruct(x, memo, *rv)
File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 270, in _reconstruct
state = deepcopy(state, memo)
File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 146, in deepcopy
y = copier(x, memo)
File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 230, in _deepcopy_dict
y[deepcopy(key, memo)] = deepcopy(value, memo)
File "/home/user/anaconda3/envs/FaceDetector/lib/python3.8/copy.py", line 161, in deepcopy
rv = reductor(4)
TypeError: cannot pickle 'module' object
So instead I tried to load the QAT’d parameters into a less confusing form of the model and then tried converting again, but got the same error:
import torchvision
from torch import nn
from vision.utils import box_utils
from vision.ssd.config.fd_config import define_img_size
define_img_size(640)
from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd
import torch.nn.functional as F
import cv2
import numpy as np
class_names = ['background', 'face']
net_1 = create_Mb_Tiny_RFB_fd(len(class_names), is_test=True, device='cpu')
net_1.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(net_1, inplace=True)
# load definition: self.load_state_dict(torch.load(model, map_location=lambda storage, loc: storage))
net_1.load(model_path) # load the previously QAT'd model (without quantisation conversion)
class SimpleNet(nn.Module):
def __init__(self, base_net, regression_headers, classification_headers, extras, priors, config):
super(SimpleNet, self).__init__()
self.backbone0 = base_net[:8]
self.backbone1 = base_net[8:11]
self.backbone2 = base_net[11:13]
self.last_chunk = base_net[13:]
self.regression_headers0 = regression_headers[0]
self.regression_headers1 = regression_headers[1]
self.regression_headers2 = regression_headers[2]
self.regression_headers3 = regression_headers[3]
self.classification_headers0 = classification_headers[0]
self.classification_headers1 = classification_headers[1]
self.classification_headers2 = classification_headers[2]
self.classification_headers3 = classification_headers[3]
self.extras = extras
self.num_classes = 2
self.priors = priors
self.config = config
self.last_op = nn.Softmax(dim=-1)
def forward(self, x):
confidences = []
locations = []
x = self.backbone0(x)
confidence = self.classification_headers0(x)
confidence = confidence.permute(0, 2, 3, 1).contiguous()
confidence = confidence.view(confidence.size(0), -1, self.num_classes)
location = self.regression_headers0(x)
location = location.permute(0, 2, 3, 1).contiguous()
location = location.view(location.size(0), -1, 4)
confidences.append(confidence)
locations.append(location)
x = self.backbone1(x)
confidence = self.classification_headers1(x)
confidence = confidence.permute(0, 2, 3, 1).contiguous()
confidence = confidence.view(confidence.size(0), -1, self.num_classes)
location = self.regression_headers1(x)
confidences.append(confidence)
locations.append(location)
x = self.backbone2(x)
confidence = self.classification_headers2(x)
confidence = confidence.permute(0, 2, 3, 1).contiguous()
confidence = confidence.view(confidence.size(0), -1, self.num_classes)
location = self.regression_headers2(x)
confidences.append(confidence)
locations.append(location)
x = self.last_chunk.forward(x)
x = self.extras(x)
confidence = self.classification_headers3(x)
confidence = confidence.permute(0, 2, 3, 1).contiguous()
confidence = confidence.view(confidence.size(0), -1, self.num_classes)
location = self.regression_headers3(x)
confidences.append(confidence)
locations.append(location)
confidences = torch.cat(confidences, 1)
confidences = self.last_op(confidences)
locations = torch.cat(locations, 1)
boxes = box_utils.convert_locations_to_boxes(
locations, self.priors, torch.tensor([0.1]), torch.tensor([0.2]) #self.config.center_variance, self.config.size_variance
)
boxes = box_utils.center_form_to_corner_form(boxes)
return confidences, boxes
model = SimpleNet(
net_1.base_net,
net_1.regression_headers,
net_1.classification_headers,
net_1.extras[0],
net_1.priors,
net_1.config)
model.eval()
model = torch.quantization.convert(model, inplace=False) # error here
Heres the `print(model)` output:
SimpleNet(
(backbone0): Sequential(
(0): Sequential(
(0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
)
(1): Sequential(
(0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
(1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(2): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(3): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(4): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(5): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(6): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(7): BasicRFB(
(branch0): Sequential(
(0): BasicConv(
(conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
(1): BasicConv(
(conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(activation_fn): ReLU()
)
(2): BasicConv(
(conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
(branch1): Sequential(
(0): BasicConv(
(conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
(1): BasicConv(
(conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(activation_fn): ReLU()
)
(2): BasicConv(
(conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3), dilation=(3, 3), bias=False)
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
(branch2): Sequential(
(0): BasicConv(
(conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(8, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
(1): BasicConv(
(conv): Conv2d(8, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(12, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(activation_fn): ReLU()
)
(2): BasicConv(
(conv): Conv2d(12, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
(activation_fn): ReLU()
)
(3): BasicConv(
(conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(5, 5), dilation=(5, 5), bias=False)
(bn): BatchNorm2d(16, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
)
(ConvLinear): BasicConv(
(conv): Conv2d(48, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
(shortcut): BasicConv(
(conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
)
(activation_fn): ReLU()
)
)
(backbone1): Sequential(
(8): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64, bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(9): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(10): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128, bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
)
(backbone2): Sequential(
(11): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=128, bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
(12): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256, bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
)
)
(last_chunk): Sequential()
(regression_headers0): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
(1): ReLU()
(2): Conv2d(64, 12, kernel_size=(1, 1), stride=(1, 1))
)
(regression_headers1): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
(1): ReLU()
(2): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))
)
(regression_headers2): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
(1): ReLU()
(2): Conv2d(256, 8, kernel_size=(1, 1), stride=(1, 1))
)
(regression_headers3): Conv2d(256, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(classification_headers0): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
(1): ReLU()
(2): Conv2d(64, 6, kernel_size=(1, 1), stride=(1, 1))
)
(classification_headers1): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
(1): ReLU()
(2): Conv2d(128, 4, kernel_size=(1, 1), stride=(1, 1))
)
(classification_headers2): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
(1): ReLU()
(2): Conv2d(256, 4, kernel_size=(1, 1), stride=(1, 1))
)
(classification_headers3): Conv2d(256, 6, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(extras): Sequential(
(0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
(1): ReLU()
(2): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
(1): ReLU()
(2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
)
(3): ReLU()
)
(last_op): Softmax(dim=-1)
)
It appeared that something could not be pickled in the model - so I tried using dill
:
dill.detect.trace(True)
dill.detect.errors(model)
and the output is
Output
T2: <class '__main__.SimpleNet'>
F2: <function _create_type at 0x7f574cb6a670>
# F2
T1: <class 'type'>
F2: <function _load_type at 0x7f574cb6a5e0>
# F2
# T1
T4: <class 'torch.nn.modules.module.Module'>
# T4
D2: <dict object at 0x7f574c9f4b40>
F1: <function SimpleNet.__init__ at 0x7f574cb86f70>
F2: <function _create_function at 0x7f574cb6a700>
# F2
Co: <code object __init__ at 0x7f57af0eb240, file "simple_ul.py", line 38>
F2: <function _create_code at 0x7f574cb6a790>
# F2
# Co
D1: <dict object at 0x7f57af1d7f00>
# D1
Ce: <cell at 0x7f574cb5c790: type object at 0x5637763f3840>
F2: <function _create_cell at 0x7f574cb6ab80>
# F2
T5: <class '__main__.SimpleNet'>
# T5
# Ce
D2: <dict object at 0x7f574c9f48c0>
# D2
# F1
F1: <function SimpleNet.forward at 0x7f574cb86ee0>
Co: <code object forward at 0x7f57af0f9450, file "simple_ul.py", line 60>
# Co
D1: <dict object at 0x7f57af1d7f00>
# D1
D2: <dict object at 0x7f574c9f4a00>
# D2
# F1
# D2
# T2
D2: <dict object at 0x7f574c9fa740>
T4: <class 'collections.OrderedDict'>
# T4
T4: <class 'torch.nn.modules.container.Sequential'>
# T4
D2: <dict object at 0x7f574c9fa4c0>
D2: <dict object at 0x7f574d8c7580>
T4: <class 'torch.nn.qat.modules.conv.Conv2d'>
# T4
D2: <dict object at 0x7f574ca8db40>
F2: <function _rebuild_parameter at 0x7f57ad6594c0>
# F2
F2: <function _rebuild_tensor_v2 at 0x7f57ad659280>
# F2
/home/user/anaconda3/envs/FaceDetector/lib/python3.8/site-packages/torch/storage.py:34: FutureWarning: pickle support for Storage will be removed in 1.5. Use `torch.save` instead
warnings.warn("pickle support for Storage will be removed in 1.5. Use `torch.save` instead", FutureWarning)
F2: <function _load_from_bytes at 0x7f5759b7b310>
# F2
T4: <class 'torch.quantization.fake_quantize.FakeQuantize'>
# T4
D2: <dict object at 0x7f5759343500>
T4: <class 'torch.quantization.observer.MovingAverageMinMaxObserver'>
# T4
D2: <dict object at 0x7f574d8b7ac0>
# D2
# D2
D2: <dict object at 0x7f574ca8de00>
T4: <class 'torch.quantization.observer.MovingAveragePerChannelMinMaxObserver'>
# T4
D2: <dict object at 0x7f574ca8de40>
# D2
# D2
T6: <class 'torch.quantization.qconfig.QConfig'>
F2: <function _create_namedtuple at 0x7f574cb6f0d0>
# F2
# T6
T4: <class 'torch.quantization.observer._with_args.<locals>._PartialWrapper'>