Canonical way to use jit.script w/ DataParallel

I have been looking at using torch.jit.script to potentially speed up some operations during training, but the documentation is a bit sparse.

I have the following questions:

  1. Does it matter if we call
    model = model.cuda()
    model = torch.jit.script(model)
    
    versus
    model = torch.jit.script(model)
    model = model.cuda()
    
    Fundamentally, I guess the question is: does the scripting process depend on which device the module is on? (Both orderings are sketched after this list.)
  2. Calling DataParallel on the jitted module. It seems like you cannot jit a DataParallel object, but are there any gotchas when wrapping a jitted module in DataParallel? (Also sketched after the list.)
  3. Saving state dicts
    On a normal model located on the GPU (with DataParallel), I can call
    torch.save(net.module.cpu().state_dict(), STATE_PATH)
    net.cuda()
    
    Is this still correct for jitted networks? (See the save/load sketch after this list.)
  4. Loading state dicts
    On a normal model located on the GPU (with DataParallel), I can call
    net.module.load_state_dict(torch.load(STATE_PATH, map_location=str(net.device)))
    
    Is this still correct? Does the jit process introduce any subtleties that require caution?
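
For concreteness, here are the two orderings from question 1 written out on a toy module (Net here is just a stand-in for my real model):

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        """Hypothetical toy module, standing in for the real model."""
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 4)

        def forward(self, x):
            return torch.relu(self.fc(x))

    # Order A: move to the GPU first, then script
    model_a = torch.jit.script(Net().cuda())

    # Order B: script first, then move the scripted module to the GPU
    model_b = torch.jit.script(Net()).cuda()

    x = torch.randn(2, 8, device="cuda")
    print(model_a(x).shape, model_b(x).shape)  # both run; the question is whether they differ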
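
And the wrapping direction I mean in question 2, continuing from the toy Net above (DataParallel around the scripted module, not jit.script around DataParallel; I don't know if replication of a scripted module behaves the same on every PyTorch version):

    scripted = torch.jit.script(Net().cuda())
    net = nn.DataParallel(scripted)  # wrap the already-scripted module
    out = net(torch.randn(16, 8, device="cuda"))
    print(out.shape)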
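
Finally, the save/load pattern I'm asking about in questions 3 and 4, again continuing from the sketches above (STATE_PATH and the "cuda:0" map_location are just placeholders for illustration):

    STATE_PATH = "net_state.pt"  # placeholder path

    # Saving (question 3): unwrap .module, move to CPU, save the state dict, move back to the GPU
    torch.save(net.module.cpu().state_dict(), STATE_PATH)
    net.cuda()

    # Loading (question 4): load onto the target device and restore into the wrapped module
    state = torch.load(STATE_PATH, map_location="cuda:0")
    net.module.load_state_dict(state)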