The gradient value from custom backward is different from param.grad

Hello.

I wrote a custom backward function for the linear layers in my Transformer model.

And yes, it works!
But when I probe the weight gradient computed inside my custom backward, its value is different from param.grad of the same layer (the norm of the gradient I compute is much larger than the norm of param.grad).

The strange part is that the model still trains well: the loss decreases normally, and the test accuracy matches the baseline code.
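
For context, I read param.grad in the usual way after the backward pass. This is a simplified sketch, not my exact training loop; model, optimizer, and loss are placeholders, and the AMP/DDP details are omitted:

loss.backward()  # my custom backward runs here and probes grad_W directly

for name, param in model.named_parameters():
    if param.grad is not None and 'weight' in name:
        # these values are far smaller than the grad_W probed inside backward
        print(name, param.grad.norm().item(), param.grad.abs().max().item())

optimizer.step()
optimizer.zero_grad()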

My code is below.


class Quantized_Linear(nn.Linear):
    def __init__(self, weight_quantize_module: Quantizer, act_quantize_module: Quantizer, weight_grad_quantize_module: Quantizer, act_grad_quantize_module: Quantizer,
                 in_features, out_features, bias=True):
        super(Quantized_Linear, self).__init__(in_features, out_features, bias=bias)
        self.weight_quantize_module = weight_quantize_module
        self.act_quantize_module = act_quantize_module
        self.weight_grad_quantize_module = weight_grad_quantize_module
        self.act_grad_quantize_module = act_grad_quantize_module

    def forward(self, input, block_num, epoch, iteration, device_id, layer_info):
        return _quantize_global.apply(block_num, epoch, iteration, device_id, layer_info, input, self.weight, self.bias, self.weight_quantize_module,
                                      self.act_quantize_module, self.weight_grad_quantize_module, self.act_grad_quantize_module)
    
class _quantize_global(torch.autograd.Function):
    @staticmethod
    def forward(ctx, block_num, epoch, iteration, device_id, layer_info, x, w, bias=None, w_qmodule=None, a_qmodule=None, w_g_qmodule=None, a_g_qmodule=None):
        #save for backward
        ctx.block_num = block_num
        ctx.iteration = iteration
        ctx.layer_info = layer_info
        ctx.g_qmodule = w_g_qmodule, a_g_qmodule
        ctx.reshape_3D_size = x.size() # x as 3D 
        ctx.has_bias = bias is not None
        ctx.epoch = epoch
        ctx.device_id=device_id
        
        x = x.view(-1, x.size(-1)) #reshape to 2D
        input_quant, s_input_quant = a_qmodule(x)   # quantize activation
        weight_quant, s_weight_quant = w_qmodule(w) # quantize weight
        ctx.input = (x, s_input_quant, w, s_weight_quant)  # note: the full-precision x and w are saved, together with their scales
        
        output = torch.matmul(input_quant, weight_quant.t())  # matmul in the quantized domain

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)

        s_o = s_weight_quant * s_input_quant  # combined rescaling factor

        return output.view(*ctx.reshape_3D_size[:-1], -1) * s_o  # rescale and reshape back to 3D


    @staticmethod
    def backward(ctx, g_3D):
        if ctx.device_id == 0 and ctx.iteration is not None:
            if ctx.iteration % 400 == 0 and ctx.layer_info is not None:
                probe(g_3D, block_num=ctx.block_num, layer=ctx.layer_info + 'X_grad_before', epoch=ctx.epoch, iteration=ctx.iteration)
        
        g_2D = g_3D.reshape(-1, g_3D.size(-1)) #reshape to 2D
        grad_X = grad_W = grad_bias = None 
        
        q_x, s_x, q_w, s_w = ctx.input  # the full-precision x and w saved in forward, plus their quantization scales

        # under mixed precision the incoming gradient is fp16, so cast the saved tensors to match
        q_x = q_x.half()
        q_w = q_w.half()

        if ctx.device_id == 0 and ctx.iteration is not None:
            if ctx.iteration % 400 == 0 and ctx.layer_info is not None:
                probe(q_w, block_num=ctx.block_num, layer=ctx.layer_info + 'weight', epoch=ctx.epoch, iteration=ctx.iteration)
                probe(q_x, block_num=ctx.block_num, layer=ctx.layer_info + 'input_x', epoch=ctx.epoch, iteration=ctx.iteration)

        w_g_qmodule, a_g_qmodule = ctx.g_qmodule
        reshape_3D = ctx.reshape_3D_size

        # gradient w.r.t. the input: quantize the incoming gradient, matmul with the saved weight, then rescale
        a_g_2D_quant, a_s_g_2D_quant = a_g_qmodule(g_2D)
        grad_X = torch.matmul(a_g_2D_quant, q_w)
        grad_X = grad_X * a_s_g_2D_quant * s_w
        grad_X = grad_X.view(reshape_3D[0], reshape_3D[1], -1)

        # gradient w.r.t. the weight: quantize the incoming gradient, matmul with the saved input, then rescale
        w_g_2D_quant, w_s_g_2D_quant = w_g_qmodule(g_2D)
        grad_W = torch.matmul(w_g_2D_quant.t(), q_x)
        grad_W = grad_W * w_s_g_2D_quant * s_x

        if ctx.has_bias:
            grad_bias = g_2D.sum(dim=0)
        else:
            grad_bias = None
        
        if ctx.device_id == 0 and ctx.iteration is not None:
            if ctx.iteration % 400 == 0 and ctx.layer_info is not None:
                probe(grad_X, block_num=ctx.block_num, layer=ctx.layer_info + 'X_grad_after', epoch=ctx.epoch, iteration=ctx.iteration)
                probe(grad_W, block_num=ctx.block_num, layer=ctx.layer_info + 'W_grad_after', epoch=ctx.epoch, iteration=ctx.iteration)
            
        return None, None, None, None, None, grad_X, grad_W, grad_bias, None, None, None, None

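For reference, probe just records statistics of the tensor it is given; the ranges quoted at the end of this post come from it. A simplified stand-in (not my exact implementation) looks like this:

def probe(tensor, block_num, layer, epoch, iteration):
    # simplified placeholder: print basic statistics of the probed tensor
    t = tensor.detach().float()
    print(f'[block {block_num}][{layer}] epoch {epoch} iter {iteration} '
          f'norm={t.norm().item():.4f} absmax={t.abs().max().item():.4f}')
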
The quantizer code is below.


class QuantizerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, N_bits, q_type, signed, symmetric, Qn, Qp, minimum_range):
        if N_bits is None:
            # bypass quantization; still save the input and a unit scale so backward works
            ctx.save_for_backward(x)
            ctx.scale = torch.tensor(1.0)
            return x, torch.tensor(1.0)

        ctx.save_for_backward(x)
        ctx.N_bits = N_bits
        ctx.q_type = q_type
        ctx.signed = signed
        ctx.symmetric = symmetric
        ctx.Qn = Qn
        ctx.Qp = Qp
        ctx.minimum_range = minimum_range

        
        if symmetric:
            if q_type == 'per_tensor':
                max_x = x.abs().max().detach()
            elif q_type == 'per_token':
                max_x = x.abs().amax(dim=-1, keepdim=True).detach()
            elif q_type == 'per_channel':
                max_x = x.abs().amax(dim=0, keepdim=True).detach()
            scale = max_x / Qp
            x_q = (x / scale).round()
            
        else:
            if q_type == 'per_tensor':
                min_x = x.min().detach()
                max_x = x.max().detach()
            elif q_type == 'per_token':
                min_x = x.min(dim=-1, keepdim=True).values.detach()
                max_x = x.max(dim=-1, keepdim=True).values.detach()
            elif q_type == 'per_channel':
                min_x = x.min(dim=0, keepdim=True).values.detach()
                max_x = x.max(dim=0, keepdim=True).values.detach()
            range_x = (max_x - min_x).clamp(min=minimum_range)
            scale = range_x / (Qp - Qn)
            zero_point = torch.round((min_x / scale) - Qn)
            x_q = ((x / scale) + zero_point).round().clamp(Qn, Qp)
        
        ctx.scale = scale
        return x_q, scale

    @staticmethod
    def backward(ctx, grad_output, grad_scale):
        x, = ctx.saved_tensors
        scale = ctx.scale

        # straight-through estimator: treat round() as the identity,
        # so only the division by the scale is propagated
        grad_x = grad_output / scale

        return grad_x, None, None, None, None, None, None, None


class Quantizer(nn.Module):
    def __init__(self, N_bits, type="per_tensor", signed=True, symmetric=True):
        super(Quantizer, self).__init__()
        self.N_bits = N_bits
        self.signed = signed
        self.symmetric = symmetric
        self.q_type = type

        if self.N_bits is None:
            self.Qn = 0
            self.Qp = 0
            self.minimum_range = 1e-5
            return

        if self.signed:
            self.Qn = -2 ** (self.N_bits - 1)
            self.Qp = 2 ** (self.N_bits - 1) - 1
        else:
            self.Qn = 0
            self.Qp = 2 ** self.N_bits - 1

        self.minimum_range = 1e-5

    def forward(self, x):
        return QuantizerFunction.apply(
            x, self.N_bits, self.q_type, self.signed, self.symmetric, self.Qn, self.Qp, self.minimum_range
        )

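For completeness, the quantized layers are constructed roughly like this. This is a simplified sketch; the bit-widths, quantization types, and dimensions are placeholder values, not necessarily the ones I train with:

import torch

# example construction of one quantized linear layer
weight_q      = Quantizer(N_bits=8, type="per_channel")   # weight quantizer
act_q         = Quantizer(N_bits=8, type="per_token")     # activation quantizer
weight_grad_q = Quantizer(N_bits=8, type="per_tensor")    # weight-gradient quantizer
act_grad_q    = Quantizer(N_bits=8, type="per_token")     # activation-gradient quantizer

qlinear = Quantized_Linear(weight_q, act_q, weight_grad_q, act_grad_q,
                           in_features=768, out_features=768, bias=True)

x = torch.randn(8, 197, 768)  # (batch, tokens, features)
# block_num / epoch / iteration / device_id / layer_info are only used for probing
out = qlinear(x, block_num=0, epoch=0, iteration=100, device_id=0, layer_info='blk0.attn.')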

Since I have to use quantization, I use a round_pass function (the same idea as a straight-through estimator) to let gradients flow through the rounding.
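
For reference, a common way to write round_pass is shown below; my version is essentially the same idea:

def round_pass(x):
    # straight-through estimator: round in the forward pass,
    # but behave like the identity in the backward pass
    y = x.round()
    return (y - x).detach() + x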

**Probing results:**
1. g_3D range: about 1~20
2. grad_W range: about 100~1000
3. grad_X range: about 1~20

Weight gradient taken from param.grad: the max value is about 0.00xx…

Does anyone know why this happens?