Does torch.export support models quantized with torchAO?

Hello,
I’m trying to quantize a model to FP8 using torchAO:

from torchao.quantization import quantize_, Float8WeightOnlyConfig

quantize_(model, Float8WeightOnlyConfig())
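
For reference, this is roughly how I export the quantized model afterwards (the example inputs below are placeholders I made up to match the input shapes shown in the graph further down):

import torch

# hypothetical example inputs matching the shapes in the exported graph
example_args = (
    torch.randn(1, 1, 64),  # x
    torch.randn(1, 4),      # freqs_cos
    torch.randn(1, 4),      # freqs_sin
)

exported_program = torch.export.export(model, example_args)
print(exported_program)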

My weights are quantized to FP8. However, when I use torch.export.export(), I observe dequantization layers that convert my weights back to FP32:

ExportedProgram:                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 
    class GraphModule(torch.nn.Module):                                                                                                                                                                                                                                                                                                                                                                                                                                                          
        def forward(self, p_wq_weight: "f32[64, 64]", p_wk_weight: "f32[32, 64]", p_wv_weight: "f32[32, 64]", p_wo_weight: "f32[64, 64]", x: "f32[1, 1, 64]", freqs_cos: "f32[1, 4]", freqs_sin: "f32[1, 4]"):                                                                                                                                                                                                                                                                                   
            # No stacktrace found for following nodes                                                                                                                                                                                                                                                                                                                                                                                                                                            
            access_subclass_inner_tensor_default_9: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wq_weight, 'tensor_impl');  p_wq_weight = None                                                                                                                                                                                                                                                                                                                  
            access_subclass_inner_tensor_default_16: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_9, 'float8_data')                                                                                                                                                                                                                                                                                                           
            access_subclass_inner_tensor_default_17: "f32[64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_9, 'scale');  access_subclass_inner_tensor_default_9 = None                                                                                                                                                                                                                                                                          
            access_subclass_inner_tensor_default_27: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wk_weight, 'tensor_impl');  p_wk_weight = None                                                                                                                                                                                                                                                                                                                 
            access_subclass_inner_tensor_default_34: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_27, 'float8_data')                                                                   
            access_subclass_inner_tensor_default_35: "f32[32]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_27, 'scale');  access_subclass_inner_tensor_default_27 = None                                                      
            access_subclass_inner_tensor_default_45: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wv_weight, 'tensor_impl');  p_wv_weight = None                                             
            access_subclass_inner_tensor_default_52: "f8e4m3fn[32, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_45, 'float8_data')                                                                                                     
            access_subclass_inner_tensor_default_53: "f32[32]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_45, 'scale');  access_subclass_inner_tensor_default_45 = None                                                                                                   
            access_subclass_inner_tensor_default_63: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(p_wo_weight, 'tensor_impl');  p_wo_weight = None                                                
            access_subclass_inner_tensor_default_70: "f8e4m3fn[64, 64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_63, 'float8_data')                                                                                                     
            access_subclass_inner_tensor_default_71: "f32[64]" = torch.ops.export.access_subclass_inner_tensor.default(access_subclass_inner_tensor_default_63, 'scale');  access_subclass_inner_tensor_default_63 = None                                                                                                   
                                                                                                                                                                                                                                                    
             # File: /home/user/deep_learning/FP8-quantization-torchAO-copy-local/model/llama2.py:129 in forward, code: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)                                                                                                                                                            
            dequantize_affine: "f32[64, 64]" = torch.ops.torchao.dequantize_affine.default(access_subclass_inner_tensor_default_16, [1, 64], access_subclass_inner_tensor_default_17, None, torch.float8_e4m3fn, -448, 448, 'NONE');  access_subclass_inner_tensor_default_16 = access_subclass_inner_tensor_default_17 = None                                                                                                                                                                   
            linear: "f32[1, 1, 64]" = torch.ops.aten.linear.default(x, dequantize_affine);  dequantize_affine = None                                                                                                                                                                                                                  
            dequantize_affine_1: "f32[32, 64]" = torch.ops.torchao.dequantize_affine.default(access_subclass_inner_tensor_default_34, [1, 64], access_subclass_inner_tensor_default_35, None, torch.float8_e4m3fn, -448, 448, 'NONE');  access_subclass_inner_tensor_default_34 = access_subclass_inner_tensor_default_35 = None                                                                                                                                                                 
            linear_1: "f32[1, 1, 32]" = torch.ops.aten.linear.default(x, dequantize_affine_1);  dequantize_affine_1 = None                                                                                                                      
            dequantize_affine_2: "f32[32, 64]" = torch.ops.torchao.dequantize_affine.default(access_subclass_inner_tensor_default_52, [1, 64], access_subclass_inner_tensor_default_53, None, torch.float8_e4m3fn, -448, 448, 'NONE');  access_subclass_inner_tensor_default_52 = access_subclass_inner_tensor_default_53 = None                                                                                                                                                                 
            linear_2: "f32[1, 1, 32]" = torch.ops.aten.linear.default(x, dequantize_affine_2);  x = dequantize_affine_2 = None                                                                                                                  
                                                                                                                        
             # File: /home/user/deep_learning/FP8-quantization-torchAO-copy-local/model/llama2.py:130 in forward, code: xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)                                                         
            view: "f32[1, 1, 8, 8]" = torch.ops.aten.view.default(linear, [1, 1, 8, 8]);  linear = None

Can anyone help me with these questions:

  • Does the export function support torchAO?
  • Why do I observe the dequantization layers?

GPU: Tesla P100-PCIE-16GB

Yeah, float8 weight-only quantization currently only does quant/dequant, which is why you see the dequantize ops in the graph. For an actual speedup we typically use Float8DynamicActivationFloat8WeightConfig (see the torchao main documentation) instead. Export is supported, but I think with the most recent code you’ll probably see decomposed ops instead of single quantize/dequantize ops.
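
A minimal sketch of what that would look like (assuming a recent torchao; model and example_args are the same placeholders as in your snippet above):

from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

# quantize weights to float8 and quantize activations dynamically at runtime,
# so the linears can run as FP8 matmuls instead of dequantizing weights to FP32
quantize_(model, Float8DynamicActivationFloat8WeightConfig())

exported_program = torch.export.export(model, example_args)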

What is the goal here?