Hey there!
I am currently implementing a biologically plausible neural network simulation, which essentially solves an ODE numerically under the hood. I have been trying to get it to a speed comparable to training a PyTorch model, but that seems very hard to achieve. I recently started profiling my implementation with torch.utils.bottleneck (https://pytorch.org/docs/stable/bottleneck.html) and found that even a single matrix multiplication I perform takes longer than the whole set of operations needed to update an NN with one batch[1]. The difference is about 13 ms for the matrix multiplication vs. 2.5 ms for the forward-backward pass in classical PyTorch code (a quick test; I can go into more detail, and a timing sketch is included below the footnote). Is there any hope of implementing fast GPU-based simulations in PyTorch that do not follow its classical training setup? A related question: are the higher-level PyTorch operations (optimizer.step(), loss.backward(), etc.) implemented in terms of lower-level PyTorch operations such as mm, bmm, etc.?
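For the second question, I suppose one could check empirically by profiling a single backward pass and looking at which operations show up. Something like this sketch (the tiny model and tensor shapes are just placeholders, not my simulation):

import torch

# tiny placeholder model just for the sketch -- not my actual network
model = torch.nn.Linear(128, 10).cuda()
x = torch.randn(64, 128, device="cuda")
loss = model(x).sum()

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA]
) as prof:
    loss.backward()

# if backward is built from low-level ops, entries such as aten::mm
# should appear in this table
print(prof.key_averages().table(sort_by="cuda_time_total"))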
[1] I mean something as simple as this:
images = images.to(device)
labels = labels.to(device)

optimizer.zero_grad()            # clear gradients from the previous step
output = model(images)           # forward pass
loss = loss_fn(output, labels)
loss.backward()                  # backward pass
optimizer.step()                 # parameter update
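For comparison, here is a minimal sketch of how the single matrix multiplication could be timed in isolation (it assumes a CUDA GPU, and the matrix sizes are purely illustrative, not my actual network dimensions). Since CUDA calls are asynchronous, torch.cuda.synchronize() is needed for the measurement to be meaningful:

import time
import torch

device = torch.device("cuda")  # assumes a CUDA GPU is available

# illustrative sizes -- not my actual network dimensions
a = torch.randn(4096, 4096, device=device)
b = torch.randn(4096, 4096, device=device)

# warm-up so one-time CUDA context / cuBLAS setup does not
# end up in the measurement
for _ in range(10):
    torch.mm(a, b)
torch.cuda.synchronize()

start = time.perf_counter()
c = torch.mm(a, b)
# without synchronizing, the timer would only measure the kernel
# launch, not the matrix multiplication itself
torch.cuda.synchronize()
print(f"single mm: {(time.perf_counter() - start) * 1e3:.2f} ms")

(If the 13 ms above was taken from a cold start, part of it may be one-time CUDA/cuBLAS initialization rather than the mm itself.)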