Inplace error with QAT

When I’m doing QAT on my network, I keep getting an in-place error (and I don’t set inplace=True anywhere in my code). I set detect_anomaly to True and tried to find the bug. It reports that a conv2d layer causes the error, but it’s hard to believe that conv2d is the culprit. Can anyone help?
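For reference, I enable anomaly detection with a single call before training (a minimal sketch, not my exact script):

import torch

# Make backward report the forward operation that produced the tensor
# whose gradient computation later fails.
torch.autograd.set_detect_anomaly(True)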

The error message is below:

/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py:130: UserWarning: Error detected in FakeQuantizePerChannelAffineBackward. Traceback of forward call that caused the error:
  File "Train_CS_OPINE_Net_plus_quan.py", line 299, in <module>
    [x_output, loss_layers_sym, Phi, x_init] = model_fp32(batch_x)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "Train_CS_OPINE_Net_plus_quan.py", line 218, in forward
    [x_final, layer_sym] = self.fcs[0](x, PhiWeight, PhiTWeight, PhiTb) #share weight
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "Train_CS_OPINE_Net_plus_quan.py", line 160, in forward
    x_conv2_backward_relu = self.convrelu2_backward(x_conv1_backward_relu)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/intrinsic/qat/modules/conv_fused.py", line 312, in forward
    self._conv_forward(input, self.weight_fake_quant(self.weight)))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/quantization/fake_quantize.py", line 99, in forward
    X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point,
 (Triggered internally at /pytorch/torch/csrc/autograd/python_anomaly_mode.cpp:104.)
  Variable._execution_engine.run_backward(
Traceback (most recent call last):
  File "Train_CS_OPINE_Net_plus_quan.py", line 319, in <module>
    loss_all.backward(retain_graph = True)
  File "/usr/local/lib/python3.8/dist-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

This is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

class BasicBlock(torch.nn.Module):
    def __init__(self):
        super(BasicBlock, self).__init__()

        self.lambda_step = nn.Parameter(torch.Tensor([0.5]), requires_grad = True)
        self.soft_thr = nn.Parameter(torch.Tensor([0.01]), requires_grad = True)

        qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
        # torch.nn.qat.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)
        # torch.nn.intrinsic.qat.ConvReLU2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', qconfig=None)
        
        self.conv_D = nn.qat.Conv2d(1, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)
        
        self.convrelu1_forward = nn.intrinsic.qat.ConvReLU2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)

        self.conv2_forward = nn.qat.Conv2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)
        self.convrelu1_backward = nn.intrinsic.qat.ConvReLU2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)

        self.convrelu2_backward = nn.intrinsic.qat.ConvReLU2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)
        self.conv2_backward = nn.qat.Conv2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)
        self.conv2_backward.weight_fake_quant = self.convrelu2_backward.weight_fake_quant # convrelu2_backward is fused with relu, but I only want the conv (without relu), and take out the weight

        self.convrelu1_G = nn.intrinsic.qat.ConvReLU2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)

        self.conv2_G = nn.qat.Conv2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)

        self.conv3_G = nn.qat.Conv2d(32, 1,(3, 3), padding=1, bias=False, qconfig = qconfig)

        if ReLU == "trelu":
            self.alpha = nn.Parameter(torch.Tensor([alpha_arg]), requires_grad = False)
        
        self.skip_add = nn.quantized.FloatFunctional()
    
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()
    def forward(self, x, PhiWeight, PhiTWeight, PhiTb):
        x_quan = self.quant(x)
        PhiTb_quan = self.quant(PhiTb)
        PhiWeight_quan = self.quant(PhiWeight)
        PhiTWeight_quan = self.quant(PhiTWeight)

        temp = F.conv2d(x_quan, PhiWeight_quan, padding=0,stride=33, bias=None)
        temp1 = F.conv2d(temp, PhiTWeight_quan, padding=0, bias=None)
        x1 = x_quan - self.lambda_step * torch.nn.PixelShuffle(33)(temp1)
        x_input = x1 + self.lambda_step * PhiTb_quan

        x_conv_D = self.conv_D(x_input) 
        
        x_conv1_forward_relu = self.convrelu1_forward(x_conv_D)
        
        x_conv2_forward = self.conv2_forward(x_conv1_forward_relu)
        x_soft = torch.mul(torch.sign(x_conv2_forward), F.relu(torch.abs(x_conv2_forward) - self.soft_thr))

        x_conv1_backward_relu = self.convrelu1_backward(x_soft)
        # x_conv1_backward_relu = x_soft + x_soft   

        x_conv2_backward_relu = self.convrelu2_backward(x_conv1_backward_relu)
        
        # x_conv2_backward_relu = x_conv1_backward_relu + x_conv1_backward_relu

        x_conv1_G_relu = self.convrelu1_G(x_conv2_backward_relu)

        x_conv2_G = self.conv2_G(x_conv1_G_relu)
        x_conv3_G = self.conv3_G(x_conv2_G)

        x_pred_quan = x_input + x_conv3_G
        # x_pred_quan = self.skip_add.add(x_input, x_conv3_G)

####################### For calculation of SSIM ######################
        x2 = self.convrelu1_backward(x_conv2_forward)
        # x2 = x_conv2_forward
        x_D_est = self.conv2_backward(x2) 
        # x_D_est = x2
        symloss_quan = x_D_est - x_conv_D
################################################################## 
        x_pred = self.dequant(x_pred_quan)
        symloss = self.dequant(symloss_quan)
        
        return [x_pred, symloss]

class OPINENetplus(torch.nn.Module):
    def __init__(self, LayerNo, n_input):
        super(OPINENetplus, self).__init__()

        self.Phi = nn.Parameter(init.xavier_normal_(torch.Tensor(n_input, 1089)))
        self.Phi_scale = nn.Parameter(torch.Tensor([0.01]))

        onelayer = []
        self.LayerNo = LayerNo

        # for i in range(LayerNo):
        for i in range(1): #share weight
            onelayer.append(BasicBlock())

        self.fcs = nn.ModuleList(onelayer)

    def forward(self, x):
        # Sampling-subnet
        Phi_ = MyBinarize(self.Phi)
        Phi = self.Phi_scale * Phi_

        PhiWeight = Phi.contiguous().view(n_input, 1, 33, 33) # Reshape Phi in order to use non-overlapping conv.
        Phix = F.conv2d(x, PhiWeight, padding=0, stride=33, bias=None)    # Get measurements

        # Initialization-subnet
        PhiTWeight = Phi.t().contiguous().view(n_output, n_input, 1, 1)
        PhiTb = F.conv2d(Phix, PhiTWeight, padding=0, bias=None)
        PhiTb = torch.nn.PixelShuffle(33)(PhiTb)

        x = PhiTb    # Conduct initialization
        x_init = x

        # Recovery-subnet
        layers_sym = []   # for computing symmetric loss
        for i in range(self.LayerNo):
            [x, layer_sym] = self.fcs[0](x, PhiWeight, PhiTWeight, PhiTb) #share weight
            layers_sym.append(layer_sym)

        x_final = x
        return [x_final, layers_sym, Phi, x_init]

I solved it by myself!
Because I call self.convrelu1_backward twice in the forward pass, its state gets overwritten in place by the second call (that’s the [torch.FloatTensor [32]] in the error, apparently the per-channel weight fake-quant buffers), so back-propagation through the first call fails. I then used different names (separate objects) for the two calls, and the error is solved!

        self.convrelu1_backward = nn.intrinsic.qat.ConvReLU2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)

        self.convrelu2_backward = nn.intrinsic.qat.ConvReLU2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)

        self.conv1_backward = nn.intrinsic.qat.ConvReLU2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)
        value1 = self.convrelu1_backward.weight_fake_quant # convrelu1_backward is fused with relu, but I only want the conv (without relu), and take out the weight
        self.conv1_backward.weight_fake_quant = copy.deepcopy(value1)
        self.conv2_backward = nn.qat.Conv2d(32, 32,(3, 3), padding=1, bias=False, qconfig = qconfig)
        value2 = self.convrelu2_backward.weight_fake_quant # convrelu2_backward is fused with relu, but I only want the conv (without relu), and take out the weight
        self.conv2_backward.weight_fake_quant = copy.deepcopy(value2)
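For anyone hitting the same message: the underlying autograd rule can be reproduced in a tiny standalone sketch that has nothing to do with QAT. A tensor that backward needs has been modified in place, so its version counter no longer matches the one autograd saved:

import torch

x = torch.randn(3, requires_grad=True)
y = x.exp()         # backward of exp() reuses the saved output y
y += 1              # in-place op bumps y's version counter
y.sum().backward()  # RuntimeError: ... modified by an inplace operation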

I have the same issue, but my head is a shared head and it has to be used many times. How should I solve this problem? Should I just not quantize this module?

Hi Wang-zipeng,
You can make different copies with copy.deepcopy (see the sketch below) so that the original gradient won’t be overwritten during back-propagation. Or what do you mean by “the head is a shared head”? Can you paste your code here?
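For example, something along these lines (just a sketch; MultiLevelHead and the argument names are made up for illustration, not taken from mmdetection):

import copy
import torch.nn as nn

class MultiLevelHead(nn.Module):
    # Keep an independent deep copy of the head per FPN level, so no single
    # module is called more than once inside one forward pass.
    def __init__(self, shared_head, num_levels):
        super(MultiLevelHead, self).__init__()
        self.heads = nn.ModuleList(
            [copy.deepcopy(shared_head) for _ in range(num_levels)]
        )

    def forward(self, feats):
        # feats: one feature map per FPN level
        return [head(f) for head, f in zip(self.heads, feats)]

Each copy then has its own parameters and fake-quant observers.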

Thank you for your reply, sir. It’s the rpn_head shared by the different FPN outputs in Faster R-CNN. I think you know that network; I used the implementation in mmdetection. The code is as follows:

    def forward_single(self, x):
        """Forward feature map of a single scale level."""
        x = self.rpn_conv(x)
        x = F.relu(x, inplace=True)
        rpn_cls_score = self.rpn_cls(x)
        rpn_bbox_pred = self.rpn_reg(x)
        return rpn_cls_score, rpn_bbox_pred

and it’s called in a for-loop over the different FPN layers. If I copy the layers, will I get the correct gradients for each copy?

Hi Wang-zipeng,

Unfortunately, I don’t know Faster R-CNN. However, I think each copy’s gradient will be computed (and it depends on the same loss function). Each copy’s gradients should be the same after the update.

By the way, I suggest using nn.ReLU instead of F.relu, because nn modules (unlike the functional calls) can be fused with the preceding layer during the QAT process.
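For example, roughly like this (just a sketch; the block and channel sizes are made up, it’s not the mmdetection code):

import torch
import torch.nn as nn

class HeadBlock(nn.Module):
    def __init__(self):
        super(HeadBlock, self).__init__()
        self.conv = nn.Conv2d(256, 256, 3, padding=1)
        self.relu = nn.ReLU()  # a named module can be fused; F.relu cannot

    def forward(self, x):
        return self.relu(self.conv(x))

m = HeadBlock().eval()
# Fuse the named conv/relu pair into one intrinsic ConvReLU2d, then prepare for QAT.
m_fused = torch.quantization.fuse_modules(m, [['conv', 'relu']])
m_fused.train()
m_fused.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
m_qat = torch.quantization.prepare_qat(m_fused)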

I hope this is helpful for you. If you have any questions, you’re welcome to discuss them with me!