Hi, I have a question about using a traced function during training. I wrote a regularization function that is quite computationally expensive and slows down training (despite being vectorized). I tried optimizing it with the JIT, so I trace it and use it during training. Is it safe to use it that way, even though the JIT is mostly discussed in the context of inference?
Scripting the model is not specific to inference use cases, and training should work just fine.
Depending on the backend used, less aggressive optimizations might be applied (since e.g. intermediate activations might need to be stored or recomputed), but training should not break.
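To make this concrete for your use case, here is a small sketch of scripting a standalone regularization function and adding it to the training loss. The function name `l2_penalty` and the weighting factor are made up for illustration; the point is that a scripted function participates in the autograd graph like any other op:

```python
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

# hypothetical expensive regularizer: here just a sum of squared weights,
# scripted independently of the model
@torch.jit.script
def l2_penalty(params: List[torch.Tensor]) -> torch.Tensor:
    # collect one scalar per parameter, then reduce
    terms = [(p * p).sum() for p in params]
    return torch.stack(terms).sum()

model = nn.Linear(10, 1)
x = torch.randn(4, 10)
target = torch.randn(4, 1)

# the scripted function's output is added to the loss; autograd
# backpropagates through it as usual
loss = F.mse_loss(model(x), target)
loss = loss + 1e-4 * l2_penalty(list(model.parameters()))
loss.backward()

print(model.weight.grad is not None)  # True
```

Whether scripting actually speeds up your particular function depends on how much of it the fuser can optimize, so it is worth profiling both versions.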
Thank you for the response! I am not sure I fully understand the internal dynamics of PyTorch during training, but what you said implies that autograd operates on traced/scripted functions too?
Yes, Autograd will still work and will calculate the same gradients (up to floating point precision).
You can run a quick check as seen in this small example:
```python
import torch
import torch.nn as nn

device = "cuda"
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
).to(device)
x = torch.randn(1, 10, device=device)

# keep a handle on the first layer's weight;
# scripting shares the same parameters
weight = model[0].weight

# reference
out = model(x)
out.mean().backward()
# store a gradient as the reference value
g0 = weight.grad.clone()
model.zero_grad()

# script
model = torch.jit.script(model)

# run a few iterations to allow JIT optimizations
for _ in range(5):
    model.zero_grad()
    out = model(x)
    out.mean().backward()

g1 = weight.grad.clone()
print((g0 - g1).abs().max())
# tensor(0., device='cuda:0')
```
Make sure to execute the scripted model a few times to allow some JIT optimizations to kick in.
Thank you very much for the quick response! It is very helpful