JIT for simple linear models?

What can be expected from JIT for simple linear models?

It’s possible I’m using torch.jit.trace_module incorrectly, but here’s my model:

import torch
from torch import nn

model = nn.Sequential(
        nn.Linear(784, 200),
        nn.Linear(200, 200),
        nn.Linear(200, 10),
)
model = model.to(DEVICE)
# Trace the forward pass with a dummy batch of the expected shape
model = torch.jit.trace_module(
        model, {'forward': torch.randn(BATCH_SIZE, 784, device=DEVICE)})

for inputs, targets in data:
  outputs = model(inputs)
  loss = loss_fn(outputs, targets)

A simple linear model for MNIST.

I have a few observations:

  1. Profiling with Nsight Systems makes it quite clear that performance without JIT is poor. This is expected: the kernels are very small, so neither the CPU nor the GPU has much to do.
  2. I’d expect JIT to be able to fuse quite a few of these kernels, though. However, when tracing as shown above, the profile looks exactly the same as without JIT. How come? Am I using JIT correctly here? And which kernels can I expect to be fused for this (simple) model?
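
One way to check what tracing actually recorded, independent of the profiler, is to list the operations in the traced graph. The following is a minimal sketch, not a definitive diagnostic: the values of DEVICE and BATCH_SIZE are assumptions (the post does not give them), and the exact operator names printed depend on the PyTorch version.

```python
import collections

import torch
from torch import nn

# Assumed values; the original post does not state them.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64

model = nn.Sequential(
    nn.Linear(784, 200),
    nn.Linear(200, 200),
    nn.Linear(200, 10),
).to(DEVICE)

example = torch.randn(BATCH_SIZE, 784, device=DEVICE)
traced = torch.jit.trace_module(model, {"forward": example})

# inlined_graph expands the calls into the submodules,
# so every recorded operation is visible at the top level.
op_counts = collections.Counter(
    node.kind() for node in traced.inlined_graph.nodes()
)
for kind, count in sorted(op_counts.items()):
    print(kind, count)

# Sanity check: the traced module should compute the same result as eager mode.
with torch.no_grad():
    same = torch.allclose(model(example), traced(example), atol=1e-5)
print("traced output matches eager:", same)
```

If the listing still shows one matrix-multiply style operation per nn.Linear layer, tracing has only recorded the layers, and the same small kernels would be launched as in eager mode.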

Thanks for bringing this to our attention.

Support for matmul + relu fusion on CPU, among other types of fusion, is on our roadmap for the next few months. After that, we might look into GPU. We will let you know when it is close to being ready.


Thanks for the quick answer, and yes, it would be good to know when this is ready :)