I am trying to implement a neural network whose output is a discrete, non-negative integer (0, 1, 2, and so on, with no upper bound). My initial thought was to use torch.round(); however, this does not work, because round() passes no useful gradient back during backpropagation. I have also tried the approach outlined in the thread Torch.round() gradient, without success.
What is the simplest way to get a neural network to output an arbitrary non-negative integer?
What did you actually try? Did it fail with error messages? No gradients?
The core problem for backpropagation is that discrete values are not usefully
differentiable: a function that produces them, such as round(), has a gradient of zero (almost) everywhere.
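A quick way to see this, assuming PyTorch, is to backpropagate through torch.round() and inspect the gradient. Autograd defines the derivative of round() as zero, so the input receives an all-zero gradient and the parameters upstream of it cannot learn anything from the loss.

```python
import torch

x = torch.tensor([0.3, 1.7, 2.5], requires_grad=True)
y = torch.round(x)   # discrete-valued output
y.sum().backward()   # backpropagate a simple scalar "loss"
print(x.grad)        # all zeros: round() contributes no gradient
```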
You have two choices:

1. Produce continuous values that tend to cluster around your desired
discrete values, and then find a way to compute a usefully differentiable
loss function from these "approximately discrete" values.
2. Produce discrete values, but replace their zero gradients with non-zero
approximate gradients that capture how your loss depends on the parameters
that produce those discrete values, in a way that leads to useful
backpropagation and training.
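As a minimal sketch of the second choice, assuming PyTorch and one unconstrained network output per example, the common "straight-through estimator" rounds in the forward pass but uses the gradient of the identity in the backward pass. The helper name round_ste is hypothetical, and softplus is just one assumed way of keeping the value non-negative before rounding; the tensor raw stands in for the network's last-layer output.

```python
import torch
import torch.nn.functional as F

def round_ste(x: torch.Tensor) -> torch.Tensor:
    """Round in the forward pass; use the identity's gradient in the backward pass."""
    # (x.round() - x) is detached, so autograd only sees the "+ x" term.
    return x + (x.round() - x).detach()

# Stand-in for the network's final-layer output (hypothetical values).
raw = torch.tensor([-1.0, 0.4, 2.6], requires_grad=True)
# softplus maps any real number to a positive one, so the rounded result is >= 0.
y = F.softplus(raw)
y_int = round_ste(y)   # integer-valued tensor that is still attached to the graph
y_int.sum().backward()
print(y_int)           # values 0., 1., 3.
print(raw.grad)        # sigmoid(raw): the softplus derivative, passed straight through
```

Whether the approximate gradient is good enough for training is problem-dependent; it only guarantees that something non-zero flows back to the parameters.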
Thanks for your answer, K.Frank!
I tried to implement a "personalised" version of the round function that rounds in the forward pass but is ignored during backpropagation (the incoming gradient is passed through unchanged):
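```python
import torch
class RoundSTE(torch.autograd.Function):  # class name assumed; used as RoundSTE.apply(x)
    @staticmethod
    def forward(ctx, input):
        return input.round()  # discrete values in the forward pass
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()  # treat round() as the identity when backpropagating
```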
Anyhow, the model does not seem to learn with that approach. Could you explain either of the solutions you mentioned in a bit more detail, and how to implement them?
Thanks so much!
Whether a given approach might work will depend very much on your
specific use case.
Here’s one hypothetical illustration:
Let's say that your loss function is given by a look-up table with ten entries
and your model predicts a discrete integer that runs from 0 to 9, so that
loss = look_up_table[prediction].
You won’t be able to backpropagate through the discrete-integer prediction.
But you could modify your model to predict a continuous value that ranges
from 0.0 to 9.0, and use your look-up table as an interpolation table (for
example, linearly interpolating between neighboring entries) so that
your loss is now a continuous function of your (continuous) prediction. You
will now have valid, useful gradients, backpropagation will work, and your
network might (or might not) train.
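A minimal sketch of the interpolation idea, assuming PyTorch and linear interpolation between neighboring entries. The table values and the helper name interpolated_loss are hypothetical; a real loss table would come from your problem. The gradient with respect to the prediction is simply the slope between the two table entries that bracket it.

```python
import torch

# Hypothetical ten-entry look-up table: the loss for predictions 0, 1, ..., 9.
look_up_table = torch.tensor([0., 1., 4., 9., 16., 25., 36., 49., 64., 81.])

def interpolated_loss(prediction: torch.Tensor, table: torch.Tensor) -> torch.Tensor:
    """Piecewise-linear interpolation of `table` at a continuous `prediction`."""
    n = table.numel()
    # Keep the prediction inside the table's range [0.0, n - 1].
    p = prediction.clamp(0.0, n - 1.0)
    # Table index just below p, clamped so that lo + 1 is always a valid index.
    lo = torch.floor(p).clamp(0, n - 2)
    frac = p - lo  # position between the two neighboring entries, in [0, 1]
    lo_idx = lo.long()
    return (1.0 - frac) * table[lo_idx] + frac * table[lo_idx + 1]

prediction = torch.tensor([2.5], requires_grad=True)
loss = interpolated_loss(prediction, look_up_table).sum()
loss.backward()
print(loss.item())       # 6.5, halfway between table[2] = 4 and table[3] = 9
print(prediction.grad)   # tensor([5.]), the slope table[3] - table[2]
```

At integer predictions the interpolated loss equals the original look-up value, so the two formulations agree wherever the discrete one is defined.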