I trained my model on multiple GPUs, but when I tried to use torch.jit.trace to export it as a .pt
model, an error occurred.
Here’s my training code.
train_loader, test_loader = get_dataloader(batch_size, root)
gpus = list(range(torch.cuda.device_count()))
se_resnet = nn.DataParallel(se_resnet50(num_classes=102), device_ids=gpus)
optimizer = optim.SGD(params=se_resnet.parameters(), lr=0.6 / 1024 * batch_size, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1)
trainer = Trainer(se_resnet, optimizer, F.cross_entropy, save_dir="./102flower_model_weights")
Here’s my jit tracing code.
gpus = list(range(torch.cuda.device_count()))
model = nn.DataParallel(se_resnet50(num_classes=102), device_ids=gpus).cuda()
checkpoint = torch.load('102flower_model_weights/model_epoch_70.pth')
model.load_state_dict(checkpoint["weight"])#, strict=False)
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# save
traced_script_module.save("102flower_model_weights/model_epoch_70.pt")
Error info:
traced_script_module.save("102flower_model_weights/model_epoch_70.pt")
RuntimeError:
could not export python function call Scatter. Remove calls to python functions before export.:
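
For reference, the Scatter call named in the error appears to come from the nn.DataParallel wrapper rather than from the network itself, so a commonly suggested workaround is to trace the wrapped network (`model.module`) instead of the wrapper. The following is only a minimal sketch under that assumption: `TinyNet` is a hypothetical stand-in for `se_resnet50`, and `state` stands in for `checkpoint["weight"]`; neither comes from the code above.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for se_resnet50(num_classes=102)."""

    def __init__(self, num_classes=102):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


# Wrap the network as in training, so the state_dict keys carry the "module." prefix.
wrapped = nn.DataParallel(TinyNet())
state = wrapped.state_dict()  # assumption: plays the role of checkpoint["weight"]

# Load the weights into the wrapper, then unwrap it.
wrapped.load_state_dict(state)
plain = wrapped.module.cpu().eval()  # the underlying network, without Scatter/Gather

# Trace the plain module rather than the DataParallel wrapper.
example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(plain, example)

# Save to an in-memory buffer here; a file path such as "model.pt" also works.
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
```

With the original code, this would amount to passing `model.module` to torch.jit.trace after `load_state_dict`, and keeping the example input on the same device as the module.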