I want to structure my neural network code so that "postprocessing" steps live inside the model itself, selected by a flag on the forward pass. For example, for a classification network big_nn I could call:
a = big_nn(x, mode='logits') # returns [-2, 2.4, 0.3, -1.1]
b = big_nn(x, mode='prediction') # returns 1
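To make the style concrete, here is a minimal sketch of what such a module might look like. The class name, layer sizes, and dimensions are made up for illustration; only the mode-flag idea comes from the question.

```python
import torch
import torch.nn as nn

class BigNN(nn.Module):
    """Hypothetical classifier whose forward pass takes a postprocessing flag."""

    def __init__(self, in_dim=8, n_classes=4):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(in_dim, 16),
            nn.ReLU(),
            nn.Linear(16, n_classes),
        )

    def forward(self, x, mode='logits'):
        logits = self.body(x)
        if mode == 'logits':
            return logits              # raw class scores
        elif mode == 'prediction':
            return logits.argmax(-1)   # index of the largest logit
        raise ValueError(f"unknown mode: {mode}")

big_nn = BigNN()
x = torch.randn(2, 8)
a = big_nn(x, mode='logits')      # shape (2, 4)
b = big_nn(x, mode='prediction')  # shape (2,)
```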
This leads to a very elegant coding style, but unfortunately in PyTorch it makes my training loop very inefficient, because the loop includes code like:
loss = softmax_cross_entropy(big_nn(x, mode='logits'), y).mean()
error = (big_nn(x, mode='prediction') == y).double().mean()
Internally, this runs the forward pass of big_nn twice and doubles the number of nodes in the autograd graph, which is expensive and something I definitely want to avoid. It could, of course, be made efficient again by writing:
logits = big_nn(x, mode='logits')
loss = softmax_cross_entropy(logits, y).mean()
error = (argmax(logits, -1) == y).double().mean()
…but from a software design perspective, I don't like having the logic for converting logits to predictions live in the training code. It belongs inside the neural network class.
Is there a way to write efficient PyTorch code in this style, by forcing it to re-use intermediate computations?
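One possible pattern, sketched here purely as an illustration of the question's intent and not claimed to be the idiomatic answer: have the module cache the logits from the most recent input, keyed on the identity of the input tensor, so that a second call with a different mode on the same batch reuses the existing graph instead of rebuilding it. The class name and layer sizes are again made up.

```python
import torch
import torch.nn as nn

class BigNN(nn.Module):
    """Hypothetical classifier that reuses logits across calls on the same batch."""

    def __init__(self, in_dim=8, n_classes=4):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(in_dim, 16),
            nn.ReLU(),
            nn.Linear(16, n_classes),
        )
        self._cache_input = None   # last input tensor object seen
        self._cache_logits = None  # logits computed for that input

    def forward(self, x, mode='logits'):
        # Recompute only if this is not the exact same tensor object as last time.
        if self._cache_input is not x:
            self._cache_logits = self.body(x)
            self._cache_input = x
        logits = self._cache_logits
        if mode == 'logits':
            return logits
        elif mode == 'prediction':
            return logits.argmax(-1)
        raise ValueError(f"unknown mode: {mode}")

big_nn = BigNN()
x = torch.randn(3, 8)
logits = big_nn(x, mode='logits')
preds = big_nn(x, mode='prediction')  # reuses the cached logits, no second body pass
```

Note the obvious fragility: the cache is keyed by object identity, so it must be invalidated after every parameter update (e.g. by resetting `_cache_input = None` after `optimizer.step()`), or the second call on the next batch's tensor will naturally recompute anyway since it is a different object.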