After quantization-aware training (QAT) of my model, I tried to save it with torch.jit.save,
but it did not work:
Hi,
Would you have a small code snippet that reproduces the crash please?
if args.qat:
    configure(args.log_dir, flush_secs=5)
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)
    # print(qat_model.parameters())
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    best_loss = 100
    for nepoch in range(args.epochs):
        model.to(device)
        train_one_epoch(nepoch, model, criterion, optimizer, data_loader, device)
        if nepoch > 90:
            # Freeze quantizer parameters
            model.apply(torch.quantization.disable_observer)
        if nepoch > 95:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        # Check the accuracy after each epoch
        model.to('cpu')
        quantized_model = torch.quantization.convert(model.eval(), inplace=False)
        quantized_model.eval()
        size = get_size_of_model(quantized_model)
        top1, top5, loss = evaluate(quantized_model, data_loader_test)
        if loss < best_loss:
            best_loss = loss
            # torch.jit.save(torch.jit.script(quantized_model), 'quantizated_best.pth')
            torch.save(quantized_model.state_dict(), 'quantizated_best.pkl')
        print(f'Epoch {nepoch} | top1: {top1} | top5: {top5} | loss {loss} | size {size}')
        log_value('Validating/Accuracy', top1, nepoch)
        log_value('Validating/Loss', loss, nepoch)
    torch.jit.save(torch.jit.script(quantized_model), 'quantizated_final.pth')
Everything works except the final torch.jit.save() call.
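It can help to check the save step in isolation. Below is a minimal sketch of the same eager-mode path (prepare_qat, then convert, then torch.jit.script, then torch.jit.save) on a hypothetical toy model, TinyNet, which contains no empty containers. The class name, the layer sizes, the single forward pass standing in for training, and saving to an in-memory buffer instead of a file are all assumptions made for illustration.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model with the stubs eager-mode quantization needs."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        return self.dequant(x)


model = TinyNet()
model.train()  # prepare_qat expects a model in training mode
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(model, inplace=True)

# Stand-in for the real training loop: one forward pass so the
# fake-quant observers record activation ranges.
with torch.no_grad():
    model(torch.randn(2, 3, 16, 16))

# Convert a copy to an int8 model, then script and serialize it.
quantized_model = torch.quantization.convert(model.eval(), inplace=False)
scripted = torch.jit.script(quantized_model)
buffer = io.BytesIO()  # a file path such as 'quantizated_final.pth' also works
torch.jit.save(scripted, buffer)

# Round-trip check: load the serialized TorchScript module back.
buffer.seek(0)
loaded = torch.jit.load(buffer)
```

If this toy version saves cleanly but the real model does not, the difference most likely lies in the real model's structure.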
I've already solved this problem, thank you!
There was an empty nn.Sequential() in my module, and torch.quantization.convert could not remove the observer from it,
so I removed it manually, and now it works.
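If editing the model definition by hand is inconvenient, one option (a swapped-in technique, not the poster's exact fix) is to replace every empty nn.Sequential with nn.Identity before calling prepare_qat. An empty nn.Sequential returns its input unchanged, just as nn.Identity does, so the forward pass is preserved. The helper replace_empty_sequentials and the model Net are hypothetical names used only for this sketch.

```python
import torch
import torch.nn as nn


def replace_empty_sequentials(module: nn.Module) -> int:
    """Hypothetical helper: recursively swap empty nn.Sequential children
    for nn.Identity and return how many were replaced."""
    count = 0
    # Iterate over a snapshot so that swapping children is safe.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Sequential) and len(child) == 0:
            # Both return the input unchanged, so behavior is identical.
            setattr(module, name, nn.Identity())
            count += 1
        else:
            count += replace_empty_sequentials(child)
    return count


class Net(nn.Module):
    """Hypothetical model with empty placeholders, like the one in the post."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Sequential())
        self.head = nn.Sequential()  # e.g. an optional block left unused

    def forward(self, x):
        return self.head(self.features(x))


net = Net()
replaced = replace_empty_sequentials(net)  # call this before prepare_qat
out = net(torch.randn(1, 3, 8, 8))
```

Here both the nested placeholder (features[1]) and the top-level one (head) are replaced, and the output shape is unchanged.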
By the way, we also have a fix for empty nn.Sequential modules: https://github.com/pytorch/pytorch/pull/28384