How to trace with torch.jit in evaluation mode using a model trained on multiple GPUs?

I used multiple GPUs to train my model, but when I tried to use torch.jit.trace to get a .pt model, an error occurred.

Here's my training code.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# get_dataloader, se_resnet50, Trainer, batch_size and root are defined elsewhere (not shown).

train_loader, test_loader = get_dataloader(batch_size, root)
gpus = list(range(torch.cuda.device_count()))
se_resnet = nn.DataParallel(se_resnet50(num_classes=102),
                            device_ids=gpus)
optimizer = optim.SGD(params=se_resnet.parameters(), lr=0.6 / 1024 * batch_size, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, 30, gamma=0.1)
trainer = Trainer(se_resnet, optimizer, F.cross_entropy, save_dir="./102flower_model_weights")

Here’s my jit tracing code.

import torch
import torch.nn as nn
# se_resnet50 is the same model definition used for training (not shown).

gpus = list(range(torch.cuda.device_count()))
model = nn.DataParallel(se_resnet50(num_classes=102), device_ids=gpus).cuda()
checkpoint = torch.load('102flower_model_weights/model_epoch_70.pth')
model.load_state_dict(checkpoint["weight"])
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# save
traced_script_module.save("102flower_model_weights/model_epoch_70.pt")

Error info:

    traced_script_module.save("102flower_model_weights/model_epoch_70.pt")
RuntimeError:
could not export python function call Scatter. Remove calls to python functions before export.:

This is a known limitation of tracing/TorchScript at the moment: it doesn't support nn.DataParallel.

What you can do instead is trace the model wrapped inside DataParallel, which is available as its .module attribute.

For example:

traced_script_module = torch.jit.trace(model.module, example)
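Below is a minimal, self-contained sketch of that workaround, assuming a recent PyTorch. The small `make_model` network is a hypothetical stand-in for `se_resnet50(num_classes=102)`, and the output path is likewise made up. It switches to eval mode before tracing (so BatchNorm uses its running statistics in the recorded graph), traces `model.module` instead of the DataParallel wrapper, puts the example input on the same device as the weights, and round-trips the result through `save`/`torch.jit.load`.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model(num_classes: int = 102) -> nn.Module:
    # Hypothetical small stand-in for se_resnet50(num_classes=102).
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, num_classes),
    )


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Same wrapping as during training; in the real script the weights would be
# restored here with model.load_state_dict(checkpoint["weight"]).
model = nn.DataParallel(make_model()).to(device)
model.eval()  # record eval-mode BatchNorm/Dropout behaviour in the trace

# Trace the wrapped module, not the DataParallel wrapper. The example input
# must live on the same device as the module's parameters.
inner = model.module
example = torch.rand(1, 3, 224, 224, device=device)
with torch.no_grad():
    traced = torch.jit.trace(inner, example)

# Hypothetical output path for the traced model.
path = os.path.join(tempfile.gettempdir(), "model_epoch_70.pt")
traced.save(path)
loaded = torch.jit.load(path, map_location=device)
```

Because the trace is taken from the plain inner module, no Scatter call is recorded and `save` no longer hits the "could not export python function call" error.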

@smth Is this resolved, or are we still tracing the module inside DataParallel? Sorry to dredge this post up after so long.


We do not support passing a DataParallel module to torch.jit.trace(). As @smth mentioned, one alternative is to trace the inner module, then call DataParallel() on it.
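A rough sketch of "trace the inner module, then call DataParallel() on it" follows, under a few assumptions. `make_model` is a hypothetical stand-in for the real network, `strip_module_prefix` is a hypothetical helper, and the checkpoint uses the `{"weight": ...}` layout from the original post. Because the weights were saved from a DataParallel-wrapped model, their keys start with `module.`; stripping that prefix lets them load into an unwrapped model that can be traced directly. Wrapping a traced module in DataParallel should work on recent PyTorch versions, but it is worth verifying on your own multi-GPU setup.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def make_model(num_classes: int = 102) -> nn.Module:
    # Hypothetical small stand-in for se_resnet50(num_classes=102).
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, num_classes),
    )


def strip_module_prefix(state_dict):
    # Hypothetical helper: nn.DataParallel stores parameters as "module.<name>";
    # drop that prefix so the weights load into an unwrapped model.
    prefix = "module."
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Simulate a checkpoint saved from a DataParallel-wrapped model.
checkpoint = {"weight": nn.DataParallel(make_model()).state_dict()}

# 1) Load the weights into a plain (unwrapped) model and trace that.
plain = make_model().to(device)
plain.load_state_dict(strip_module_prefix(checkpoint["weight"]))
plain.eval()
example = torch.rand(1, 3, 224, 224, device=device)
with torch.no_grad():
    traced = torch.jit.trace(plain, example)

# 2) Wrap the traced module in DataParallel only for multi-GPU inference.
# With no visible GPUs, DataParallel simply calls the wrapped module.
parallel = nn.DataParallel(traced)
```

The design choice here is to keep DataParallel entirely outside the traced graph: the exported .pt file contains only the single-device model, and parallelism is added back at inference time if needed.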
