Efficient way to remove trailing NaN values

Is there an efficient way to remove trailing NaN values from a tensor?

Given Input: [NaN, 1, 2, NaN, 4, NaN, NaN]
Desired Output: [NaN, 1, 2, NaN, 4]
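One vectorized way to express this transformation, assuming a 1-D floating-point tensor, is to build a mask of non-NaN entries with torch.isnan, take the index of the last True entry via torch.nonzero, and slice up to it. This is a minimal sketch, not a benchmarked answer. The name strip_trailing_nan is a hypothetical helper, not a PyTorch API.

```python
import torch

def strip_trailing_nan(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: drop trailing NaNs from a 1-D float tensor."""
    valid = ~torch.isnan(x)        # True where the value is not NaN
    idx = torch.nonzero(valid)     # shape (k, 1): indices of non-NaN values
    if idx.numel() == 0:           # every element is NaN
        return x[:0]
    cutoff = idx[-1].item() + 1    # one past the last non-NaN index
    return x[:cutoff]

nan = float("nan")
x = torch.tensor([nan, 1, 2, nan, 4, nan, nan])
print(strip_trailing_nan(x))
```

Note that this scans the whole tensor, and the .item() call forces a device-to-host synchronization when the tensor lives on a GPU.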

In my specific case, the number of values to remove is likely to be fewer than 50, while the entire tensor has more than a thousand elements. So the most efficient approach might simply be to iterate backwards from the last element to find the right cutoff point, and then do “output = input[:cutoff]”. I’m just curious whether there is a more efficient PyTorch way.
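The backward-iteration idea described above could look roughly like this, assuming a 1-D floating-point tensor. It is a sketch, and strip_trailing_nan_loop is a hypothetical helper name.

```python
import math
import torch

def strip_trailing_nan_loop(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: walk back from the end until a non-NaN value is found."""
    cutoff = x.shape[0]
    # Each step reads one element as a Python float; stop at the first non-NaN.
    while cutoff > 0 and math.isnan(x[cutoff - 1].item()):
        cutoff -= 1
    return x[:cutoff]

nan = float("nan")
x = torch.tensor([nan, 1, 2, nan, 4, nan, nan])
print(strip_trailing_nan_loop(x))
```

This only touches the trailing elements, so the number of iterations equals the number of NaNs removed plus one. However, each iteration is a separate Python-level element access (and, on a GPU, a separate synchronization), so whether it beats a single vectorized pass depends on the device and tensor size.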