I've been looking at using `jit.script` to potentially speed up some operations during training, but the documentation is a bit sparse.
I have the following questions:
- Does it matter if we call
  ```python
  model = model.cuda()
  model = torch.jit.script(model)
  ```
  versus
  ```python
  model = torch.jit.script(model)
  model = model.cuda()
  ```
  Fundamentally, I guess the question is: does the jit process depend on the device of the module?
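  For concreteness, here is a toy version of the two orderings I'm comparing (the `Net` module is just a stand-in):
  ```python
  import torch
  import torch.nn as nn

  class Net(nn.Module):
      def __init__(self):
          super().__init__()
          self.fc = nn.Linear(4, 2)

      def forward(self, x):
          return self.fc(x)

  # Order A: move the module to the GPU, then script it
  net_a = torch.jit.script(Net().cuda())

  # Order B: script the module on the CPU, then move it
  net_b = torch.jit.script(Net()).cuda()

  # Both report their parameters on the GPU, but is the
  # compiled code equivalent in the two cases?
  print(next(net_a.parameters()).device)  # cuda:0
  print(next(net_b.parameters()).device)  # cuda:0
  ```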
- Calling DataParallel on the jitted module. It seems like you cannot jit a DataParallel object, but are there any gotchas when using DataParallel on a jitted module?
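  In case it helps, the pattern I have in mind is wrapping the already-scripted module (a scripted module is still an `nn.Module`, so this should at least be legal):
  ```python
  # Script first, then wrap in DataParallel -- the reverse order
  # (scripting the DataParallel wrapper itself) does not seem to work.
  scripted = torch.jit.script(Net().cuda())  # Net as in the sketch above
  net = nn.DataParallel(scripted)
  out = net(torch.randn(8, 4).cuda())
  ```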
- Saving state dicts
  On a normal model located on the GPU (with DataParallel), I can call
  ```python
  torch.save(net.module.cpu().state_dict(), STATE_PATH)
  net.cuda()
  ```
  Is this correct for jitted networks?
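  The only alternative I'm aware of would be serializing the whole scripted module rather than just the state dict, something like this (`MODEL_PATH` is a placeholder):
  ```python
  # Serialize the entire scripted module (compiled code + parameters),
  # not just the state dict
  torch.jit.save(net.module, MODEL_PATH)
  ```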
- Loading state dicts
  On a normal model located on the GPU (with DataParallel), I can call
  ```python
  net.module.load_state_dict(torch.load(STATE_PATH, map_location=str(net.device)))
  ```
  Is this still correct? Does the jit process introduce any subtleties that require caution?
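  And correspondingly, if the whole-module route from the previous sketch is the right one, I assume loading would look like this (`MODEL_PATH` as above):
  ```python
  # Load the whole scripted module back, remapping storages
  # onto the desired device
  net = torch.jit.load(MODEL_PATH, map_location="cuda:0")
  ```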