What can be expected from JIT for simple linear models?
It’s possible I’m using `torch.jit.trace_module` incorrectly, but here’s my model:
```python
# DEVICE, BATCH_SIZE, optimizer, loss_fn and the data loader are defined elsewhere.
model = nn.Sequential(
    nn.Linear(784, 200),
    nn.ReLU(),
    nn.Linear(200, 200),
    nn.ReLU(),
    nn.Linear(200, 10),
)
model = model.to(DEVICE)
model = torch.jit.trace_module(
    model, {'forward': torch.randn(BATCH_SIZE, 784, device=DEVICE)})

for inputs, targets in data:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
```
A simple linear model for MNIST.
I have a few observations:
- Profiling with Nsight Systems makes it clear that without JIT the performance is poor. That’s expected: the kernels are tiny, so kernel-launch overhead dominates and neither the CPU nor the GPU has much work to do.
- I’d expect JIT to be able to fuse quite a few of these kernels, though. However, when using `trace_module` as shown above, the situation is exactly the same as without JIT. How come? Am I using JIT correctly here? And which kernels can I expect to be fused for this (simple) model?
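For what it’s worth, here is a CPU-only sketch of the same trace that runs anywhere, plus a way to inspect the graph the tracer recorded. The batch size of 8 is an arbitrary stand-in for `BATCH_SIZE`, and the snippet stays on CPU so it’s self-contained:

```python
import torch
import torch.nn as nn

# Same architecture as in the snippet above, kept on CPU.
model = nn.Sequential(
    nn.Linear(784, 200),
    nn.ReLU(),
    nn.Linear(200, 200),
    nn.ReLU(),
    nn.Linear(200, 10),
)

# Trace with an example input; batch size 8 is an arbitrary placeholder.
traced = torch.jit.trace_module(
    model, {'forward': torch.randn(8, 784)})

# The initial traced graph simply lists the recorded ops. Any fusion
# would happen later, when the graph executor optimizes the graph
# after a few warm-up runs, not in this initial graph.
print(traced.graph)

# Sanity check: the traced module matches eager execution.
x = torch.randn(8, 784)
print(torch.allclose(traced(x), model(x)))
```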