I am trying to use the torch.kthvalue function. The code runs fine on the CPU, but on the GPU I get the following error:
RuntimeError: Type FloatTensor doesn't implement stateless method kthvalue
Is this a missing PyTorch feature? If so, are there any workarounds?
Edit: For now, I perform the kthvalue operation on the CPU and move the results back to the GPU.
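Below is a minimal sketch of two ways to work around this. `kthvalue_via_cpu` is the round-trip described in the edit above. `kthvalue_via_topk` is a swapped-in alternative that keeps the data on the device, assuming `torch.topk` is supported on CUDA in the PyTorch version in use. Both helper names are hypothetical.

```python
import torch


def kthvalue_via_cpu(x: torch.Tensor, k: int, dim: int = -1):
    # Workaround from the edit: copy to host, run torch.kthvalue there,
    # then move both the values and the indices back to x's device.
    values, indices = torch.kthvalue(x.cpu(), k, dim)
    return values.to(x.device), indices.to(x.device)


def kthvalue_via_topk(x: torch.Tensor, k: int, dim: int = -1):
    # Assumption: torch.topk is available on x's device.
    # With largest=False and sorted=True, topk returns the k smallest
    # elements in ascending order, so position k - 1 along `dim` holds
    # the k-th smallest value (1-indexed k, as in torch.kthvalue).
    values, indices = torch.topk(x, k, dim=dim, largest=False, sorted=True)
    return values.select(dim, k - 1), indices.select(dim, k - 1)


# Usage: falls back to the CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([[3.0, 1.0, 2.0], [9.0, 7.0, 8.0]], device=device)

v_cpu, i_cpu = kthvalue_via_cpu(x, 2)    # 2nd smallest in each row
v_topk, i_topk = kthvalue_via_topk(x, 2)
```

The topk version avoids the host/device copies, but when the input contains ties it may report a different (equally valid) index than torch.kthvalue would.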