TorchScript seems to be focused on low-latency inference. Can it also be used for model training? If so, are there any examples out there, e.g., training a ResNet on MNIST?
I recently used JAX to train a simple MLP and saw a 100x speed-up from JIT compilation, but I generally prefer PyTorch.
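For concreteness, here is a minimal sketch of the kind of TorchScript training I mean: a small MLP compiled with `torch.jit.script` and trained with an ordinary optimizer loop. To keep it self-contained it uses an MLP instead of a ResNet and random tensors as a stand-in for flattened MNIST images (both assumptions); the names `MLP` and `train_scripted` and the layer sizes are hypothetical. It only illustrates that gradients flow through the scripted module and says nothing about speed.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Hypothetical small MLP; the layer sizes are arbitrary illustration choices."""

    def __init__(self, in_dim: int = 784, hidden: int = 128, out_dim: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden)
        self.fc2 = nn.Linear(hidden, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


def train_scripted(steps: int = 50) -> list:
    """Train a TorchScript-compiled MLP on random data; returns the loss per step."""
    torch.manual_seed(0)
    # Compile the module with TorchScript; autograd still works through it.
    model = torch.jit.script(MLP())
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    # Assumption: random stand-in for 256 flattened 28x28 images with 10 classes.
    x = torch.randn(256, 784)
    y = torch.randint(0, 10, (256,))
    losses = []
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    losses = train_scripted()
    print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Whether a loop like this actually runs faster than eager PyTorch is exactly what I don't know.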