ResNet18 FX QAT to ONNX

  import torch
  import torch.nn as nn
  import torch.optim as optim
  import torch.ao.quantization.quantize_fx as quantize_fx
  from torch.ao.quantization import get_default_qat_qconfig_mapping
  from tqdm import tqdm

  from resnet import resnet18
  from utils import prepare_dataloader
  
  
  def train(args):
      device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")

      model = resnet18(num_classes=10)

      # "fbgemm" is the x86 server backend; use "qnnpack" for ARM targets.
      qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

      # prepare_qat_fx expects the model in train mode and a tuple of example
      # inputs; it inserts fake-quantize modules for QAT.
      model.train()
      example_inputs = (torch.randn(1, 3, 224, 224),)
      model = quantize_fx.prepare_qat_fx(model, qconfig_mapping, example_inputs)

      train_dataloader, test_dataloader = prepare_dataloader(batch_size=args.batch_size)

      # Move the model to the target device before constructing the optimizer.
      model.to(device)
      loss_fn = nn.CrossEntropyLoss()
      optimizer = optim.Adam(model.parameters(), lr=1e-2)
      
      for epoch in range(args.epoch):
          epoch_loss = 0.0
          for image, label in tqdm(train_dataloader, total=len(train_dataloader)):
              image, label = image.to(device), label.to(device)
              output = model(image)
              loss = loss_fn(output, label)

              optimizer.zero_grad()
              loss.backward()
              optimizer.step()

              epoch_loss += loss.item()

          print(f"Epoch: {epoch}, Avg Loss: {epoch_loss / len(train_dataloader):.4f}")
      
      # After training, move to CPU and switch to eval mode before converting
      # to a reference quantized model (explicit quantize/dequantize ops).
      model.to("cpu")
      model.eval()
      model = quantize_fx.convert_to_reference_fx(model)
      return model
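
For reference, this is the kind of driver I call it with; the flag names and defaults below are just illustrative, the function only relies on args.device, args.batch_size, and args.epoch:

  import argparse

  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
      parser.add_argument("--device", default="0", help="CUDA device index")
      parser.add_argument("--batch_size", type=int, default=64)
      parser.add_argument("--epoch", type=int, default=3)
      args = parser.parse_args()

      quantized_model = train(args)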

What should I do next to export this model to ONNX or run it with TensorRT?

We don’t really support export to ONNX; you’d be better off following a guide, or asking on the ONNX side when you run into an issue.
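
For what it’s worth, the route people usually try is torch.onnx.export on the reference-quantized model, since its quantize/dequantize pairs map to QDQ (QuantizeLinear/DequantizeLinear) nodes in ONNX. A minimal, untested sketch — the file name, opset, and axis labels are my own choices, and since this path is only partially supported you should expect missing-op errors:

  import torch

  # quantized_model: output of convert_to_reference_fx, on CPU and in eval mode
  dummy_input = torch.randn(1, 3, 224, 224)

  torch.onnx.export(
      quantized_model,
      dummy_input,
      "resnet18_qat.onnx",   # output path (arbitrary)
      opset_version=13,      # per-channel QDQ requires opset >= 13
      input_names=["input"],
      output_names=["output"],
      dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}},
  )

If the export succeeds, TensorRT can build an engine from the QDQ graph with something like trtexec --onnx=resnet18_qat.onnx --int8, since TensorRT 8+ honors explicit Q/DQ nodes in the network.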