(This looks like undefined behavior.)
Environment: PyTorch 1.7.0, Python 3.7, installed via conda.
I am trying to apply quantization-aware training (QAT) to EfficientNet. Everything works well until I wrap my model in DataParallel. Then it crashes randomly with an AssertionError raised from the calculate_qparams call: most often (~90% of the time) during the first training batch, sometimes only after 5-10 batches.
At first I assumed that some weights had blown up to NaN, but:
- if so, why does everything work without DataParallel?
- the initial weights are not NaN, yet in most cases the model cannot complete even a single forward pass.
The crash never happens during evaluation. The layer on which the script crashes is different every time.
I can provide the full code via GitHub if needed. I could not create a minimal reproducible example: the problem does not reproduce with simple networks.
Three example stack traces follow (the last two are shortened because their first lines are identical to those of the first):
#1
Traceback (most recent call last):
File "/home/researcher/project_src/src/quantization/qat.py", line 161, in <module>
main(main_settings, qat_settings)
File "/home/researcher/project_src/src/quantization/qat.py", line 126, in main
model_input_feature=train_settings.model_input_feature)
File "/home/researcher/project_src/src/train.py", line 34, in train_one_epoch
outputs = model(_input)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
outputs = self.parallel_apply(replicas, inputs, kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
output.reraise()
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/_utils.py", line 428, in reraise
raise self.exc_type(msg)
AssertionError: Caught AssertionError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/single_frame_template.py", line 34, in forward
_, embedding = self.feature_extractor(x)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 127, in forward
features, x = self.extract_features(inputs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 114, in extract_features
x = block(x, drop_connect_rate=drop_connect_rate)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/mb_conv_block.py", line 81, in forward
x_squeezed = self._se_expand(self._activation(self._se_reduce(x_squeezed)))
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/utils.py", line 34, in forward
x = self.conv2d(x)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 731, in _call_impl
hook_result = hook(self, input, result)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/quantize.py", line 82, in _observer_forward_hook
return self.activation_post_process(output)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 90, in forward
_scale, _zero_point = self.calculate_qparams()
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 85, in calculate_qparams
return self.activation_post_process.calculate_qparams()
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 402, in calculate_qparams
return self._calculate_qparams(self.min_val, self.max_val)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 248, in _calculate_qparams
min_val, max_val
AssertionError: min 0.23065748810768127 should be less than max 0.23042967915534973
#2
AssertionError: Caught AssertionError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/single_frame_template.py", line 34, in forward
_, embedding = self.feature_extractor(x)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 127, in forward
features, x = self.extract_features(inputs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 106, in extract_features
x = self._activation_stem(self._bn0(self._conv_stem(inputs)))
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/utils.py", line 34, in forward
x = self.conv2d(x)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/qat/modules/conv.py", line 32, in forward
return self._conv_forward(input, self.weight_fake_quant(self.weight))
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 90, in forward
_scale, _zero_point = self.calculate_qparams()
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 85, in calculate_qparams
return self.activation_post_process.calculate_qparams()
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 632, in calculate_qparams
return self._calculate_qparams(self.min_vals, self.max_vals)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 252, in _calculate_qparams
min_val, max_val
AssertionError: min tensor([ -477068.8750, -36602.4375, -153710.1406, -673136.7500,
-1487181.5000, -887609.8750, -2090753.6250, -3073480.7500,
-135975.0625, -1857073.3750, -181464.7656, -228204.1094,
429050.0938, -678035.3750, 1770874.6250, -940542.6250,
-34900.1250, -92754.3359, -276413.1875, -346014.0938,
-771024.1875, nan, -2376894.5000, -2290334.0000,
-96295.6406, -910624.7500, -20395.1934, 411769.1875,
-206277.0938, -126536.9297, -240083.9688, nan],
device='cuda:0') should be less than max tensor([ 262472.2500, 523413.2500, 206003.6875, 138193.0625,
-625572.7500, -141716.9844, -521715.0000, -1475260.0000,
1374850.2500, 87139.2031, 52821.7578, 156602.1094,
1140287.7500, 897314.3750, 3182818.0000, 274146.6562,
445342.1875, 458844.4688, 382576.7812, 250239.7031,
378798.8438, nan, 70838.8828, -85726.8125,
107011.6406, -338712.8125, 377927.9375, 2790590.2500,
95255.9531, 304543.9688, 215083.4688, nan],
device='cuda:0')
#3
AssertionError: Caught AssertionError in replica 0 on device 0.
Original Traceback (most recent call last):
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/single_frame_template.py", line 34, in forward
_, embedding = self.feature_extractor(x)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 127, in forward
features, x = self.extract_features(inputs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 114, in extract_features
x = block(x, drop_connect_rate=drop_connect_rate)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/mb_conv_block.py", line 76, in forward
x = self._activation(self._bn1(self._depthwise_conv(x)))
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 731, in _call_impl
hook_result = hook(self, input, result)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/quantize.py", line 82, in _observer_forward_hook
return self.activation_post_process(output)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 90, in forward
_scale, _zero_point = self.calculate_qparams()
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 85, in calculate_qparams
return self.activation_post_process.calculate_qparams()
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 402, in calculate_qparams
return self._calculate_qparams(self.min_val, self.max_val)
File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 248, in _calculate_qparams
min_val, max_val
AssertionError: min nan should be less than max nan