Using a traced regularization function during training

Hi, I have a question about using a traced function during training. I wrote a regularization function that is quite computationally expensive and slows down training (despite vectorizing it). I tried optimizing it with the JIT, so I trace it and then use it during training. Is it safe to use it that way, even though the JIT is usually presented as an inference tool?

Scripting the model is not specific to inference use cases, and training should work just fine.
Depending on the backend used, less aggressive optimizations might be applied (since e.g. intermediate activations might need to be stored or recomputed), but training should not break.
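
For a standalone regularization function (rather than a whole scripted model), the usage looks the same: call the traced function inside the training loop and add its output to the loss before calling backward(). Here is a minimal sketch, using a simple L2-style penalty as a hypothetical stand-in for your actual, more expensive regularizer:

import torch
import torch.nn as nn

# hypothetical stand-in for the expensive, vectorized regularization term
def regularizer(w):
    return (w ** 2).sum()

model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
criterion = nn.MSELoss()

# trace once with an example input of the right shape
traced_reg = torch.jit.trace(regularizer, model.weight.detach())

x = torch.randn(8, 10)
target = torch.randn(8, 10)

for _ in range(3):
    optimizer.zero_grad()
    loss = criterion(model(x), target) + 1e-4 * traced_reg(model.weight)
    loss.backward()  # gradients flow through the traced function
    optimizer.step()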


Thank you for the response! I am not sure I fully understand the internal dynamics of PyTorch during training, but what you said implies that autograd operates on traced/scripted functions too?

Yes, Autograd will still work and will calculate the same gradients (up to floating point precision).
You can run a quick check as seen in this small example:

device = "cuda"
model = nn.Sequential(
    nn.Linear(10, 10),
    nn.ReLU(),
    nn.Linear(10, 10)
).to(device)

x = torch.randn(1, 10, device=device)

# reference
out = model(x)
out.mean().backward()

# store a gradient as the reference value
g0 = model[0].weight.grad.clone()
model.zero_grad()

# script
model = torch.jit.script(model)

# run a few iterations to allow JIT optimizations
for _ in range(5):
    model.zero_grad()
    out = model(x)
    out.mean().backward()
    
g1 = model[0].weight.grad.clone()

print((g0 - g1).abs().max())
# tensor(0., device='cuda:0')

Make sure to execute the scripted model a few times to allow some JIT optimizations to kick in.
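
The same kind of gradient check also works for a traced standalone function, which is closer to your regularization use case. A small sketch, again with a placeholder penalty instead of the real regularizer:

import torch

def penalty(w):
    return (w.sin() ** 2).sum()

w = torch.randn(10, 10, requires_grad=True)

# eager reference gradient
penalty(w).backward()
g_ref = w.grad.clone()
w.grad = None

# traced version of the same function
traced_penalty = torch.jit.trace(penalty, w.detach())
traced_penalty(w).backward()

print((g_ref - w.grad).abs().max())
# expected to be (close to) 0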


Thank you very much for the quick response! It is very helpful 🙂