Hello.
I wrote a custom backward function for a linear layer in a Transformer class, and it works!
However, when I probe the weight gradient computed inside the custom backward, the value is different from param.grad of the same layer (the norm of the custom gradient is much larger than param.grad's).
The strange part is that the model still trains well: the loss decreases normally and the test accuracy matches the baseline code.
My code is below.
```python
import torch
import torch.nn as nn


class Quantized_Linear(nn.Linear):
    def __init__(self, weight_quantize_module: Quantizer, act_quantize_module: Quantizer,
                 weight_grad_quantize_module: Quantizer, act_grad_quantize_module: Quantizer,
                 in_features, out_features, bias=True):
        super(Quantized_Linear, self).__init__(in_features, out_features, bias=bias)
        self.weight_quantize_module = weight_quantize_module
        self.act_quantize_module = act_quantize_module
        self.weight_grad_quantize_module = weight_grad_quantize_module
        self.act_grad_quantize_module = act_grad_quantize_module

    def forward(self, input, block_num, epoch, iteration, device_id, layer_info):
        return _quantize_global.apply(block_num, epoch, iteration, device_id, layer_info, input,
                                      self.weight, self.bias, self.weight_quantize_module,
                                      self.act_quantize_module, self.weight_grad_quantize_module,
                                      self.act_grad_quantize_module)


class _quantize_global(torch.autograd.Function):
    @staticmethod
    def forward(ctx, block_num, epoch, iteration, device_id, layer_info, x, w, bias=None,
                w_qmodule=None, a_qmodule=None, w_g_qmodule=None, a_g_qmodule=None):
        # save metadata and tensors for backward
        ctx.block_num = block_num
        ctx.iteration = iteration
        ctx.layer_info = layer_info
        ctx.g_qmodule = w_g_qmodule, a_g_qmodule
        ctx.reshape_3D_size = x.size()  # x as 3D
        ctx.has_bias = bias is not None
        ctx.epoch = epoch
        ctx.device_id = device_id

        x = x.view(-1, x.size(-1))  # reshape to 2D
        input_quant, s_input_quant = a_qmodule(x)
        weight_quant, s_weight_quant = w_qmodule(w)
        ctx.input = (x, s_input_quant, w, s_weight_quant)

        output = torch.matmul(input_quant, weight_quant.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        s_o = s_weight_quant * s_input_quant
        return output.view(*ctx.reshape_3D_size[:-1], -1) * s_o

    @staticmethod
    def backward(ctx, g_3D):
        if ctx.device_id == 0 and ctx.iteration is not None:
            if ctx.iteration % 400 == 0 and ctx.layer_info is not None:
                probe(g_3D, block_num=ctx.block_num, layer=ctx.layer_info + 'X_grad_before',
                      epoch=ctx.epoch, iteration=ctx.iteration)

        g_2D = g_3D.reshape(-1, g_3D.size(-1))  # reshape to 2D
        grad_X = grad_W = grad_bias = None
        q_x, s_x, q_w, s_w = ctx.input

        # mixed-precision training: the incoming gradient flows in fp16
        q_x = q_x.half()
        q_w = q_w.half()

        if ctx.device_id == 0 and ctx.iteration is not None:
            if ctx.iteration % 400 == 0 and ctx.layer_info is not None:
                probe(q_w, block_num=ctx.block_num, layer=ctx.layer_info + 'weight',
                      epoch=ctx.epoch, iteration=ctx.iteration)
                probe(q_x, block_num=ctx.block_num, layer=ctx.layer_info + 'input_x',
                      epoch=ctx.epoch, iteration=ctx.iteration)

        w_g_qmodule, a_g_qmodule = ctx.g_qmodule
        reshape_3D = ctx.reshape_3D_size

        # gradient w.r.t. the input: quantize the output gradient, matmul with the weight, rescale
        a_g_2D_quant, a_s_g_2D_quant = a_g_qmodule(g_2D)
        grad_X = torch.matmul(a_g_2D_quant, q_w)
        grad_X = grad_X * a_s_g_2D_quant * s_w
        grad_X = grad_X.view(reshape_3D[0], reshape_3D[1], -1)

        # gradient w.r.t. the weight: quantize the output gradient, matmul with the input, rescale
        w_g_2D_quant, w_s_g_2D_quant = w_g_qmodule(g_2D)
        grad_W = torch.matmul(w_g_2D_quant.t(), q_x)
        grad_W = grad_W * w_s_g_2D_quant * s_x

        grad_bias = g_2D.sum(dim=0) if ctx.has_bias else None

        if ctx.device_id == 0 and ctx.iteration is not None:
            if ctx.iteration % 400 == 0 and ctx.layer_info is not None:
                probe(grad_X, block_num=ctx.block_num, layer=ctx.layer_info + 'X_grad_after',
                      epoch=ctx.epoch, iteration=ctx.iteration)
                probe(grad_W, block_num=ctx.block_num, layer=ctx.layer_info + 'W_grad_after',
                      epoch=ctx.epoch, iteration=ctx.iteration)

        return None, None, None, None, None, grad_X, grad_W, grad_bias, None, None, None, None
```
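For context, inside the Transformer block the layer is called roughly like the sketch below. The `Block` class, `dim`, `proj`, and `layer_info='proj_'` are illustrative names, not my exact model code; it also uses the `Quantizer` class shown further down.

```python
import torch.nn as nn

class Block(nn.Module):
    # Simplified sketch (illustrative names): how a Transformer block would call the
    # layer, threading the probing metadata (block_num, epoch, iteration, device_id)
    # through forward so the custom backward can log gradients.
    def __init__(self, dim, block_num):
        super().__init__()
        self.block_num = block_num
        self.proj = Quantized_Linear(
            Quantizer(8), Quantizer(8), Quantizer(8), Quantizer(8),
            in_features=dim, out_features=dim, bias=True,
        )

    def forward(self, x, epoch, iteration, device_id):
        # layer_info tags the probe output so gradients of different layers can be told apart
        return self.proj(x, self.block_num, epoch, iteration, device_id, layer_info='proj_')
```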
The additional quantizer code is below.
```python
class QuantizerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, N_bits, q_type, signed, symmetric, Qn, Qp, minimum_range):
        if N_bits is None:
            return x, torch.tensor(1.0)

        ctx.save_for_backward(x)
        ctx.N_bits = N_bits
        ctx.q_type = q_type
        ctx.signed = signed
        ctx.symmetric = symmetric
        ctx.Qn = Qn
        ctx.Qp = Qp
        ctx.minimum_range = minimum_range

        if symmetric:
            if q_type == 'per_tensor':
                max_x = x.abs().max().detach()
            elif q_type == 'per_token':
                max_x = x.abs().amax(dim=-1, keepdim=True).detach()
            elif q_type == 'per_channel':
                max_x = x.abs().amax(dim=0, keepdim=True).detach()
            scale = max_x / Qp
            x_q = (x / scale).round()
        else:
            if q_type == 'per_tensor':
                min_x = x.min().detach()
                max_x = x.max().detach()
            elif q_type == 'per_token':
                min_x = x.min(dim=-1, keepdim=True).values.detach()
                max_x = x.max(dim=-1, keepdim=True).values.detach()
            elif q_type == 'per_channel':
                min_x = x.min(dim=0, keepdim=True).values.detach()
                max_x = x.max(dim=0, keepdim=True).values.detach()
            range_x = (max_x - min_x).clamp(min=minimum_range)
            scale = range_x / (Qp - Qn)
            zero_point = torch.round((min_x / scale) - Qn)
            x_q = ((x / scale) + zero_point).round().clamp(Qn, Qp)

        ctx.scale = scale
        return x_q, scale

    @staticmethod
    def backward(ctx, grad_output, grad_scale):
        x, = ctx.saved_tensors
        scale = ctx.scale
        # straight-through style: pass the incoming gradient through, rescaled by the scale
        grad_x = grad_output / scale
        return grad_x, None, None, None, None, None, None, None


class Quantizer(nn.Module):
    def __init__(self, N_bits, type="per_tensor", signed=True, symmetric=True):
        super(Quantizer, self).__init__()
        self.N_bits = N_bits
        self.signed = signed
        self.symmetric = symmetric
        self.q_type = type
        if self.N_bits is None:
            self.Qn = 0
            self.Qp = 0
            self.minimum_range = 1e-5
            return
        if self.signed:
            self.Qn = -2 ** (self.N_bits - 1)
            self.Qp = 2 ** (self.N_bits - 1) - 1
        else:
            self.Qn = 0
            self.Qp = 2 ** self.N_bits - 1
        self.minimum_range = 1e-5

    def forward(self, x):
        return QuantizerFunction.apply(
            x, self.N_bits, self.q_type, self.signed, self.symmetric,
            self.Qn, self.Qp, self.minimum_range
        )
```
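For reference, this is roughly how I would sanity-check the layer against a plain FP32 nn.Linear. The shapes, bit-widths, and variable names here are illustrative, not my actual training setup, and it needs a GPU because the custom backward casts the saved tensors to `.half()`.

```python
# Illustrative sanity check (not my training code): compare the quantized layer's
# accumulated weight.grad against a plain FP32 nn.Linear that shares its weights.
import torch
import torch.nn as nn

device = 'cuda'
qlinear = Quantized_Linear(
    Quantizer(8), Quantizer(8), Quantizer(8), Quantizer(8),   # per-tensor 8-bit quantizers
    in_features=64, out_features=64, bias=True,
).to(device).half()

ref = nn.Linear(64, 64, bias=True).to(device)
with torch.no_grad():
    ref.weight.copy_(qlinear.weight.float())
    ref.bias.copy_(qlinear.bias.float())

x = torch.randn(2, 16, 64, device=device, dtype=torch.half, requires_grad=True)
# iteration=None keeps the probe() calls in backward disabled for this check
out = qlinear(x, block_num=0, epoch=0, iteration=None, device_id=0, layer_info=None)
out.sum().backward()        # grad_W from the custom backward is accumulated into qlinear.weight.grad

out_ref = ref(x.detach().float().requires_grad_())
out_ref.sum().backward()

print('quantized weight.grad norm:', qlinear.weight.grad.norm().item())
print('reference weight.grad norm:', ref.weight.grad.norm().item())
```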
Since I have to use quantization, I use a round_pass function (i.e., a straight-through estimator) so that gradients can flow through the rounding.
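(round_pass itself is not shown above; the straight-through pattern I mean is roughly the sketch below, not necessarily my exact implementation.)

```python
import torch

def round_pass(x: torch.Tensor) -> torch.Tensor:
    # Straight-through rounding: forward returns round(x), backward passes the
    # gradient through unchanged (a generic sketch, not necessarily my exact code).
    return (x.round() - x).detach() + x
```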
**Probing results:**

1. g_3D range: about 1~20
2. grad_W range: about 100~1000
3. grad_X range: about 1~20
4. Weight gradient from param.grad: the max value is 0.00xx…
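For the param.grad numbers, I read the accumulated gradient after loss.backward() and before optimizer.step(), roughly like the sketch below (illustrative names, not my exact logging code).

```python
import torch.nn as nn

def log_param_grads(model: nn.Module) -> None:
    # Simplified sketch (illustrative names): print max and norm of the accumulated
    # weight gradients after loss.backward() and before optimizer.step().
    for name, param in model.named_parameters():
        if param.grad is not None and name.endswith('.weight'):
            print(name, 'max:', param.grad.abs().max().item(), 'norm:', param.grad.norm().item())
```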
Does anyone know what could cause this?