Quantization-aware training with DataParallel crashes randomly

(Looks like undefined behavior)
PyTorch 1.7.0, Python 3.7, conda installation

I am trying to apply QAT to EfficientNet. Everything works fine until I wrap my model in DataParallel; then it crashes randomly with an AssertionError in calculate_qparams, most of the time (~90%) on the very first training batch, sometimes after 5-10 batches.
At first I assumed that some weights blow up to NaN, but

  1. everything works fine without DataParallel, which NaN weights would not explain;
  2. the initial weights are not NaN, yet in most cases the model cannot complete even a single forward pass.

The crash never happens during evaluation, and the layer the script crashes on is different every time.

I can provide the full code via GitHub if needed - I could not create a compact minimal example, since the issue does not reproduce with simple networks.
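For context, here is roughly how the model is prepared and wrapped - a minimal sketch, where EfficientNetQAT and train_loader stand in for my actual model and data loader:

import torch
import torch.nn as nn
import torch.quantization as tq

model = EfficientNetQAT()              # placeholder for my quantization-friendly EfficientNet
model.train()

# standard eager-mode QAT preparation
model.qconfig = tq.get_default_qat_qconfig('fbgemm')
tq.prepare_qat(model, inplace=True)

# without this line everything trains fine; with it the crash appears randomly
model = nn.DataParallel(model, device_ids=[0, 1]).cuda()

for _input, _target in train_loader:   # train_loader is a placeholder
    outputs = model(_input.cuda())     # AssertionError raised here, inside calculate_qparams
    ...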

Three stack trace examples (the last two are shortened because their first lines are identical):
#1

Traceback (most recent call last):
  File "/home/researcher/project_src/src/quantization/qat.py", line 161, in <module>
    main(main_settings, qat_settings)
  File "/home/researcher/project_src/src/quantization/qat.py", line 126, in main
    model_input_feature=train_settings.model_input_feature)
  File "/home/researcher/project_src/src/train.py", line 34, in train_one_epoch
    outputs = model(_input)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 161, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 171, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
AssertionError: Caught AssertionError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/single_frame_template.py", line 34, in forward
    _, embedding = self.feature_extractor(x)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 127, in forward
    features, x = self.extract_features(inputs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 114, in extract_features
    x = block(x, drop_connect_rate=drop_connect_rate)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/mb_conv_block.py", line 81, in forward
    x_squeezed = self._se_expand(self._activation(self._se_reduce(x_squeezed)))
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/utils.py", line 34, in forward
    x = self.conv2d(x)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 731, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/quantize.py", line 82, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 90, in forward
    _scale, _zero_point = self.calculate_qparams()
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 85, in calculate_qparams
    return self.activation_post_process.calculate_qparams()
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 402, in calculate_qparams
    return self._calculate_qparams(self.min_val, self.max_val)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 248, in _calculate_qparams
    min_val, max_val
AssertionError: min 0.23065748810768127 should be less than max 0.23042967915534973

#2

AssertionError: Caught AssertionError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/single_frame_template.py", line 34, in forward
    _, embedding = self.feature_extractor(x)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 127, in forward
    features, x = self.extract_features(inputs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 106, in extract_features
    x = self._activation_stem(self._bn0(self._conv_stem(inputs)))
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/utils.py", line 34, in forward
    x = self.conv2d(x)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/qat/modules/conv.py", line 32, in forward
    return self._conv_forward(input, self.weight_fake_quant(self.weight))
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 90, in forward
    _scale, _zero_point = self.calculate_qparams()
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 85, in calculate_qparams
    return self.activation_post_process.calculate_qparams()
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 632, in calculate_qparams
    return self._calculate_qparams(self.min_vals, self.max_vals)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 252, in _calculate_qparams
    min_val, max_val
AssertionError: min tensor([ -477068.8750,   -36602.4375,  -153710.1406,  -673136.7500,
        -1487181.5000,  -887609.8750, -2090753.6250, -3073480.7500,
         -135975.0625, -1857073.3750,  -181464.7656,  -228204.1094,
          429050.0938,  -678035.3750,  1770874.6250,  -940542.6250,
          -34900.1250,   -92754.3359,  -276413.1875,  -346014.0938,
         -771024.1875,           nan, -2376894.5000, -2290334.0000,
          -96295.6406,  -910624.7500,   -20395.1934,   411769.1875,
         -206277.0938,  -126536.9297,  -240083.9688,           nan],
       device='cuda:0') should be less than max tensor([  262472.2500,   523413.2500,   206003.6875,   138193.0625,
         -625572.7500,  -141716.9844,  -521715.0000, -1475260.0000,
         1374850.2500,    87139.2031,    52821.7578,   156602.1094,
         1140287.7500,   897314.3750,  3182818.0000,   274146.6562,
          445342.1875,   458844.4688,   382576.7812,   250239.7031,
          378798.8438,           nan,    70838.8828,   -85726.8125,
          107011.6406,  -338712.8125,   377927.9375,  2790590.2500,
           95255.9531,   304543.9688,   215083.4688,           nan],
       device='cuda:0')

#3

AssertionError: Caught AssertionError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/single_frame_template.py", line 34, in forward
    _, embedding = self.feature_extractor(x)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 127, in forward
    features, x = self.extract_features(inputs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/model.py", line 114, in extract_features
    x = block(x, drop_connect_rate=drop_connect_rate)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/project_src/lib/models/qfriendly/template_efficientnet/mb_conv_block.py", line 76, in forward
    x = self._activation(self._bn1(self._depthwise_conv(x)))
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 731, in _call_impl
    hook_result = hook(self, input, result)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/quantize.py", line 82, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 90, in forward
    _scale, _zero_point = self.calculate_qparams()
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 85, in calculate_qparams
    return self.activation_post_process.calculate_qparams()
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 402, in calculate_qparams
    return self._calculate_qparams(self.min_val, self.max_val)
  File "/home/researcher/miniconda3/lib/python3.7/site-packages/torch/quantization/observer.py", line 248, in _calculate_qparams
    min_val, max_val
AssertionError: min nan should be less than max nan

Hi @e_sh, if there is a repro you can link to that would be awesome.

I’m discussing it with my colleagues now. It is part of a large project that cannot be made publicly available. I am trying to extract the necessary code from it so that the issue can still be reproduced.

Repro is here: https://bitbucket.org/ms_lilibeth/qat_dataparallel_bug/src/master/

I’ve reproduced the bug only with my custom implementation of EfficientNet. In order to make it quantization-friendly I’ve done the following:

  1. Multiplication and addition operations were replaced with the
    quantization-friendly FloatFunctional (see the sketch after this list).
  2. Different MBConvBlocks are used: they contain an nn.Conv2d module instance instead of a functional conv2d() call.
  3. Padding was changed to be symmetric with respect to the kernel size (3x3 kernel => 1x1 padding; 5x5 kernel => 2x2 padding), so ZeroPad2d is no longer needed.
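
Roughly, the first change looks like this (a simplified sketch; the block structure is illustrative, not the exact code from the repo):

import torch
import torch.nn as nn

class MBConvBlockSketch(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self._depthwise_conv = nn.Conv2d(channels, channels, kernel_size=3,
                                         padding=1, groups=channels, bias=False)
        # quantization-friendly replacements for `*` and `+`
        self.mul = nn.quantized.FloatFunctional()
        self.add = nn.quantized.FloatFunctional()

    def forward(self, x):
        identity = x
        out = self._depthwise_conv(x)
        out = self.mul.mul(out, torch.sigmoid(out))  # instead of out * torch.sigmoid(out)
        out = self.add.add(out, identity)            # instead of out + identity (skip connection)
        return out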

Run the script multiple times: sometimes the bug appears, sometimes it does not. The original implementation of EfficientNet never fails.

Thank you, we will take a look.

Hi @e_sh, apologies for the long delay, and thank you so much for providing the repro. Your repro helped us find a bug in PyTorch. The issue was that for per-channel observers, replication did not work properly in some cases in nn.DataParallel, specifically when the scale and zero_point buffers were not initialized. We have a fix upcoming in fix unflatten_dense_tensor when there is empty tensor inside by zhaojuanmao · Pull Request #50321 · pytorch/pytorch · GitHub. A short-term workaround you could use before the fix lands is to run an input through the network before replicating it - this will populate all the quantization buffers.
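For example, something along these lines (a sketch; EfficientNetQAT and calibration_batch are placeholders for your model and any representative input batch):

import torch

model = EfficientNetQAT()
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)
model.cuda().train()

# run one forward pass on a single device so that every observer's
# scale / zero_point buffers get populated before replication
with torch.no_grad():
    model(calibration_batch.cuda())

# only now wrap the model; the buffers can be replicated correctly
model = torch.nn.DataParallel(model)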

As an aside, I’d recommend looking into nn.DistributedDataParallel; it is generally preferred over nn.DataParallel.