Custom non-traditional simulation implementations in PyTorch - any hope for performance?

Hey there!

I am currently implementing a biologically plausible neural network simulation which essentially solves an ODE numerically under the hood. I tried to get it to a speed similar to training a regular PyTorch model, but it seems very hard to do. I recently started to profile my implementation using bottleneck (https://pytorch.org/docs/stable/bottleneck.html) and found that even a single matrix multiplication I perform takes longer than the whole set of operations needed to update an NN with one batch [1]. The difference is about 13 ms for the matrix multiplication vs. 2.5 ms for the forward-backward pass in classical PyTorch code (quick test, I can go more into that). Is there any hope of implementing fast GPU-based simulations in PyTorch that do not follow its classical training setup? Another question: are the PyTorch higher-level operations (optimizer.step(), loss.backward(), etc.) implemented in terms of PyTorch lower-level operations such as mm, bmm, etc.?
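
For completeness, this is how I invoke the profiler (the script name is just a placeholder for my actual simulation script):

python -m torch.utils.bottleneck my_simulation.py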

[1] I mean something as simple as this:
images = images.to(device)      # move the batch to the GPU
labels = labels.to(device)
optimizer.zero_grad()           # clear gradients from the previous step
output = model(images)          # forward pass
loss = loss_fn(output, labels)
loss.backward()                 # backward pass
optimizer.step()                # parameter update

NNs with ODEs have been implemented successfully, see e.g. https://github.com/rtqichen/torchdiffeq/
In general, there is nothing to keep you from using the lower-level functions in PyTorch or mixing them with the higher-level ones.
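
For example, here is a minimal sketch of an explicit Euler integration step written directly with low-level ops (the dynamics and sizes are made up for illustration):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Made-up rate dynamics dh/dt = -h + W @ tanh(h), integrated with explicit Euler.
n = 1000
W = torch.randn(n, n, device=device) / n ** 0.5
h = torch.zeros(n, 1, device=device)
dt = 0.1

for _ in range(100):
    h = h + dt * (-h + torch.mm(W, torch.tanh(h)))  # low-level mm plus pointwise ops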
The optimizers are implemented using them, see the code in torch/optim/.
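
For instance, plain SGD (without momentum or weight decay) boils down to something like this sketch:

import torch

def sgd_step(params, lr):
    # One SGD update per parameter: p <- p - lr * p.grad, using in-place low-level ops.
    with torch.no_grad():
        for p in params:
            if p.grad is not None:
                p.add_(p.grad, alpha=-lr)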
When benchmarking, one has to be careful about the asynchronous nature of CUDA: kernels are launched asynchronously, so you need to synchronize before taking timings. If your results are too surprising, chances are that there is room for refinement of your measurement methodology.
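
A rough sketch of how I would time a matrix multiplication on the GPU (sizes are made up; the important part is torch.cuda.synchronize, which waits for the queued kernels to actually finish):

import time
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

torch.mm(a, b)            # warm-up: the first call pays one-time setup costs
torch.cuda.synchronize()  # drain the queue before starting the clock

start = time.perf_counter()
for _ in range(100):
    torch.mm(a, b)
torch.cuda.synchronize()  # otherwise we only measure kernel launches
print((time.perf_counter() - start) / 100 * 1e3, "ms per mm")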

Best regards

Thomas

I realized that my comparison was possibly a bit unfair. I was comparing my recurrent neural network with a feedforward neural network, so the computational complexity was vastly different (apart from the ODE vs. no-dynamics comparison, which is also a bit unfair). I will compare again against a regular RNN in PyTorch (non-LSTM). Am I correct that nn.RNN is a regular vanilla RNN? The description says it is an Elman RNN, which is not exactly true, since it has no context neurons.
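
For reference, this is the layer I plan to compare against (sizes are placeholders); as far as I understand the docs, each step computes h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh), i.e. a plain tanh recurrence:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=64, hidden_size=128, batch_first=True)
x = torch.randn(32, 100, 64)   # (batch, seq_len, features)
output, h_n = rnn(x)
print(output.shape)            # torch.Size([32, 100, 128])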