Hello,
I'm trying to quantize a model to FP8 using torchao:
quantize_(model, Float8WeightOnlyConfig())
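For context, here is roughly the full script I'm running. It's a minimal sketch: the toy module, dimensions, and input shape are stand-ins for my actual llama2.py model.

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, Float8WeightOnlyConfig

# Toy stand-in for the attention block in my llama2.py (the real model has wq/wk/wv/wo).
class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.wo(self.wq(x))

model = ToyBlock().eval()

# Weight-only FP8 (e4m3) quantization of the nn.Linear weights.
quantize_(model, Float8WeightOnlyConfig())

x = torch.randn(1, 1, 64)
ep = torch.export.export(model, (x,))
print(ep)  # prints a graph like the one below
```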
The weights are indeed quantized to FP8. However, when I run torch.export.export() on the quantized model, I observe dequantization ops that convert my weights back to FP32:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_wq_weight: "f32[64, 64]", p_wk_weight: "f32[32, 64]", p_wv_weight: "f32[32, 64]", p_wo_weight: "f32[64, 64]", x: "f32[1, 1, 64]", freqs_cos: "f32[1, 4]", freqs_sin: "f32[1, 4]"):
            # No stacktrace found for following nodes
            access_subclass_inner_tensor_default_9: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wq_weight, 'tensor_impl'); p_wq_weight = None
            access_subclass_inner_tensor_default_16: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_9, 'float8_data')
            access_subclass_inner_tensor_default_17: "f32[64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_9, 'scale'); access_subclass_inner_tensor_default_9 = None
            access_subclass_inner_tensor_default_27: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wk_weight, 'tensor_impl'); p_wk_weight = None
            access_subclass_inner_tensor_default_34: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_27, 'float8_data')
            access_subclass_inner_tensor_default_35: "f32[32]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_27, 'scale'); access_subclass_inner_tensor_default_27 = None
            access_subclass_inner_tensor_default_45: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wv_weight, 'tensor_impl'); p_wv_weight = None
            access_subclass_inner_tensor_default_52: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_45, 'float8_data')
            access_subclass_inner_tensor_default_53: "f32[32]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_45, 'scale'); access_subclass_inner_tensor_default_45 = None
            access_subclass_inner_tensor_default_63: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wo_weight, 'tensor_impl'); p_wo_weight = None
            access_subclass_inner_tensor_default_70: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_63, 'float8_data')
            access_subclass_inner_tensor_default_71: "f32[64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_63, 'scale'); access_subclass_inner_tensor_default_63 = None

            # File: /home/user/deep_learning/FP8-quantization-torchAO-copy-local/model/llama2.py:129 in forward, code: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
            dequantize_affine: "f32[64, 64]" = torch.ops.torchao.dequantize_affine.default(access_subclass_inner_tensor_default_16, [1, 64], access_subclass_inner_tensor_default_17, None, torch.float8_e4m3fn, -448, 448, 'NONE'); access_subclass_inner_tensor_default_16 = access_subclass_inner_tensor_default_17 = None
            linear: "f32[1, 1, 64]" = torch.ops.aten.linear.default(x, dequantize_affine); dequantize_affine = None
            dequantize_affine_1: "f32[32, 64]" = torch.ops.torchao.dequantize_affine.default(access_subclass_inner_tensor_default_34, [1, 64], access_subclass_inner_tensor_default_35, None, torch.float8_e4m3fn, -448, 448, 'NONE'); access_subclass_inner_tensor_default_34 = access_subclass_inner_tensor_default_35 = None
            linear_1: "f32[1, 1, 32]" = torch.ops.aten.linear.default(x, dequantize_affine_1); dequantize_affine_1 = None
            dequantize_affine_2: "f32[32, 64]" = torch.ops.torchao.dequantize_affine.default(access_subclass_inner_tensor_default_52, [1, 64], access_subclass_inner_tensor_default_53, None, torch.float8_e4m3fn, -448, 448, 'NONE'); access_subclass_inner_tensor_default_52 = access_subclass_inner_tensor_default_53 = None
            linear_2: "f32[1, 1, 32]" = torch.ops.aten.linear.default(x, dequantize_affine_2); x = dequantize_affine_2 = None

            # File: /home/user/deep_learning/FP8-quantization-torchAO-copy-local/model/llama2.py:130 in forward, code: xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            view: "f32[1, 1, 8, 8]" = torch.ops.aten.view.default(linear, [1, 1, 8, 8]); linear = None
```
Can anyone help me with these questions:
- Does torch.export.export() support torchao-quantized models?
- Why do I observe the dequantization ops, i.e. why are the FP8 weights converted back to FP32 in the exported graph?
GPU: Tesla P100-PCIE-16GB