Only one element tensors can be converted

I'm getting this error, please help me.

As the error message says, the item() operation can only convert a tensor with exactly one element to a Python scalar, while get_token_high_prob returns a tensor with multiple elements.

import torch

x = torch.randn(2)  # tensor with two elements
x.item() # error
> ValueError: only one element tensors can be converted to Python scalars

x = torch.randn(1)
x.item() # works

You would thus either have to make sure get_token_high_prob returns a single value, or remove the item() call and handle the multi-element result differently.
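
As a rough sketch of what "handle the result differently" could look like: I don't know what get_token_high_prob actually returns in your code, so the `result` tensor below is only a hypothetical stand-in (assumed to be a 1-D tensor with several values). Which option fits depends on what you want to do with the output.

```python
import torch

# Hypothetical stand-in for the output of get_token_high_prob
# (assumption: a 1-D tensor holding several values).
result = torch.tensor([0.1, 0.7, 0.2])

# Option 1: reduce to a single element first, then item() is valid.
best_index = result.argmax().item()  # Python int
best_value = result.max().item()     # Python float

# Option 2: keep every element and convert to a plain Python list.
values = result.tolist()

# Option 3: select one element explicitly before calling item().
first_value = result[0].item()

# Optional guard: only call item() when there is exactly one element.
scalar = result.item() if result.numel() == 1 else None
```

tolist() works for tensors of any shape, so it is probably the safest choice if you need all the values on the Python side.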