Differentiable round or early indexing method

I have an LSTM with a binary output that I weight based on an output value between 0 and 1, with values above 0.5 counting as “right” and values below 0.5 as “wrong”.

I’m trying to get the model to heavily favor early predictions, so I need the index of the earliest occurrence of a value greater than 0.5.
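To make the target concrete, here is a minimal plain-Python sketch of the hard version of the quantity I'm after. The function name `first_index_above` and the sample values are made up for illustration; the real outputs would come from the LSTM as a tensor, one value per timestep.

```python
def first_index_above(probs, threshold=0.5):
    """Return the index of the first value strictly above `threshold`.

    Returns None if no value exceeds the threshold.
    `probs` is assumed to be a sequence of per-timestep outputs in [0, 1].
    """
    for t, p in enumerate(probs):
        if p > threshold:  # hard comparison, equivalent to rounding the output
            return t
    return None


# Hypothetical per-timestep outputs for a single sequence.
probs = [0.10, 0.30, 0.70, 0.40, 0.90]
first_hit = first_index_above(probs)
```

For the sample values this picks index 2, the first timestep whose output exceeds 0.5.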

I’m not sure how to approach this, since round is non-differentiable and I don’t see a clear way to do it without rounding or doing an argsort. Any suggestions?