from torch.ao.quantization import (
get_default_qat_qconfig_mapping,
QConfigMapping,
)
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torch.ao.quantization.quantize_fx as quantize_fx
from resnet import resnet18
from utils import prepare_dataloader, model_timeit
from tqdm import tqdm
def train(args):
    """Run quantization-aware training (QAT) and return the converted model.

    A ``resnet18`` with 10 output classes is prepared for QAT with FX graph
    mode quantization using the default "fbgemm" QAT qconfig mapping, trained
    with Adam and cross-entropy loss, and finally converted with
    ``quantize_fx.convert_to_reference_fx``.

    Args:
        args: Object exposing ``device`` (CUDA device index; only used when
            CUDA is available), ``batch_size`` (forwarded to
            ``prepare_dataloader``) and ``epoch`` (number of training epochs).

    Returns:
        The model produced by ``quantize_fx.convert_to_reference_fx``, so the
        caller can evaluate or export it.
    """
    device = torch.device(
        f"cuda:{args.device}" if torch.cuda.is_available() else "cpu"
    )

    model = resnet18(num_classes=10)
    qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

    # QAT preparation is done on a model in training mode.
    model.train()
    # prepare_qat_fx documents example_inputs as a tuple of positional
    # arguments for forward(), not a bare tensor.
    example_inputs = (torch.randn(1, 3, 224, 224),)
    model = quantize_fx.prepare_qat_fx(model, qconfig_mapping, example_inputs)

    # Only the training split is consumed here.
    train_dataloader, _ = prepare_dataloader(batch_size=args.batch_size)

    # Place the model on the target device before building the optimizer so
    # the optimizer is constructed over parameters in their final location.
    model.to(device)
    loss_fc = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-2)

    for epoch in range(args.epoch):
        epoch_loss = 0.0
        for image, label in tqdm(train_dataloader, total=len(train_dataloader)):
            image, label = image.to(device), label.to(device)
            output = model(image)
            loss = loss_fc(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulates the summed (not averaged) batch losses for the epoch.
            epoch_loss += loss.item()
        print(f"Epoch: {epoch}, Epoch Loss: {epoch_loss}")

    # Swap observers / fake-quant modules for reference quantized patterns.
    model = quantize_fx.convert_to_reference_fx(model)
    return model
# What should I do next to export this model so that it can run with TensorRT or ONNX?