Examples for training with JIT

TorchScript seems to be focused on low-latency inference. Can it also be used for model training? If so, are there any examples out there, e.g. training a ResNet on MNIST?

Recently, I tried JAX to train a simple MLP model and saw a 100x speed-up from JIT compilation, but in general I prefer PyTorch.
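For context, here is a minimal sketch of the kind of JIT-compiled MLP training step described above, written in JAX. It is an illustration under stated assumptions, not the original code: the layer sizes (784-128-10, MNIST-shaped), the learning rate, the random stand-in data, and the helper names `init_params`, `mlp`, `loss_fn`, and `train_step` are all hypothetical. The key point is that `jax.jit` compiles the whole step (forward pass, gradient, and parameter update) into one function.

```python
import jax
import jax.numpy as jnp

# Hypothetical hyperparameter, captured as a constant by the jitted step.
LEARNING_RATE = 0.1


def init_params(key, sizes):
    """Create (weight, bias) pairs for each layer with He-style initialization."""
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:]):
        w = jax.random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


def mlp(params, x):
    """ReLU hidden layers followed by a linear output layer (logits)."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def loss_fn(params, x, y):
    """Mean softmax cross-entropy against one-hot targets y."""
    log_probs = jax.nn.log_softmax(mlp(params, x))
    return -jnp.mean(jnp.sum(y * log_probs, axis=-1))


@jax.jit
def train_step(params, x, y):
    """One SGD step; compiled by XLA on the first call, reused afterwards."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Random stand-in for a single MNIST-shaped batch (not real data).
params_key, data_key = jax.random.split(jax.random.PRNGKey(0))
params = init_params(params_key, [784, 128, 10])
x = jax.random.normal(data_key, (32, 784))
y = jax.nn.one_hot(jnp.arange(32) % 10, 10)

losses = []
for step in range(100):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Most of any speed-up comes from fusing the entire step into one compiled computation. The first call pays the compilation cost, and later calls with the same shapes reuse it.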
