Re-use of intermediate graph nodes

I want to structure my neural network code so that “postprocessing” steps live inside the model itself, selected by a flag on the forward pass. For a classification network big_nn, I could then call:

a = big_nn(x, mode='logits') # returns [-2, 2.4, 0.3, -1.1]
b = big_nn(x, mode='prediction') # returns 1
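
Concretely, the forward I have in mind looks something like this (BigNN, its layers, and the sizes are a made-up stand-in for my real model):

import torch.nn as nn

class BigNN(nn.Module):
    def __init__(self, in_dim=32, n_classes=4):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(in_dim, 64), nn.ReLU(), nn.Linear(64, n_classes))

    def forward(self, x, mode='logits'):
        logits = self.body(x)
        if mode == 'logits':
            return logits
        if mode == 'prediction':
            return logits.argmax(dim=-1)
        raise ValueError(f"unknown mode: {mode!r}")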

This leads to a very elegant coding style, but in PyTorch it unfortunately makes my training loop very inefficient! The reason is that the loop contains code like:

loss = F.cross_entropy(big_nn(x, mode='logits'), y)  # F = torch.nn.functional
error = (big_nn(x, mode='prediction') != y).float().mean()

Internally, this runs the forward pass of big_nn twice and doubles the number of nodes in the autograd graph, which is expensive and exactly what I want to avoid. It could of course be made efficient again by writing:

logits = big_nn(x, mode='logits')
loss = F.cross_entropy(logits, y)
error = (logits.argmax(dim=-1) != y).float().mean()

…but from a software design perspective, I don’t like having the logic for converting logits into decisions live in the training code. It belongs inside the neural network class.

Is there a way to write efficient PyTorch code in this style, by forcing it to re-use intermediate computations?

One potential approach would be to compute both outputs internally and either return both (letting the caller discard the one it doesn’t need) or use the mode argument to select which output is returned.
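
Sketched against the hypothetical BigNN above (the 'both' mode and the tuple return are illustrative, not a fixed API):

def forward(self, x, mode='both'):
    logits = self.body(x)               # single forward pass
    prediction = logits.argmax(dim=-1)  # non-differentiable, adds no autograd cost
    if mode == 'logits':
        return logits
    if mode == 'prediction':
        return prediction
    return logits, prediction           # caller discards what it doesn't need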

That’s what I’m doing now, but it still doesn’t help much when the two calls happen in different parts of my code: call A needs output 1 and call B needs output 2, and I don’t want to pass intermediate results around just to make both calls work. I just want the intermediate computations to be cached by default until, e.g., .backward() is called.
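
The closest approximation I’ve found is to memoize inside the module and clear the cache by hand once per step. A rough sketch (CachedBigNN and clear_cache are names I made up; keying on id(x) is fragile because ids can be reused after garbage collection):

class CachedBigNN(BigNN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._logit_cache = {}

    def forward(self, x, mode='logits'):
        key = id(x)  # assumes both calls receive the same tensor object
        if key not in self._logit_cache:
            self._logit_cache[key] = self.body(x)
        logits = self._logit_cache[key]
        if mode == 'prediction':
            return logits.argmax(dim=-1)
        return logits

    def clear_cache(self):
        # Call once per training step (e.g. right after backward());
        # otherwise the cached graphs keep the whole history alive.
        self._logit_cache.clear()

This works, but it’s exactly the bookkeeping I’d like the framework to handle for me.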

If there’s no way to do this currently, might it be worth adding as a feature? I think it would bring real efficiency gains to many workflows and enable much more elegant code. TensorFlow 1.x was able to do this natively, and I used that ability all the time (which is how I developed this coding style).
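
For reference, this is roughly the TensorFlow 1.x pattern I mean; the single dense layer and the placeholder shapes below are stand-ins, sketched from memory:

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

x = tf.placeholder(tf.float32, [None, 32])
y = tf.placeholder(tf.int64, [None])
logits = tf.layers.dense(x, 4)           # stand-in for the real big_nn
prediction = tf.argmax(logits, axis=-1)  # reuses the logits node
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))
error = tf.reduce_mean(tf.cast(tf.not_equal(prediction, y), tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_x = np.random.randn(8, 32).astype(np.float32)
    batch_y = np.random.randint(0, 4, size=8)
    # One run evaluates the shared subgraph once for both fetches:
    loss_val, error_val = sess.run([loss, error],
                                   feed_dict={x: batch_x, y: batch_y})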