Why is Tensor.topk() so slow?

Hi, all
I found that tensor.topk() is very slow when evaluating a semantic segmentation model. The code snippet is below. output has shape 2x5x297x817, and computing the top-1 prediction on the GPU takes about 3.6 s. Is there any way to optimize this code?

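import time

output_var = model(input_var)  # model and input_var come from the evaluation loop
output = output_var.data
start = time.time()
_, pred = output.topk(1, 1, True, True)  # k=1, dim=1 (class dimension), largest=True, sorted=True
print("Top one Time :{}".format(time.time() - start))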
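One thing worth checking before optimizing: CUDA kernels launch asynchronously, so time.time() around a single call can charge topk() for work still queued from the forward pass, or miss work that has not finished yet. Below is a minimal timing sketch that calls torch.cuda.synchronize() around the measured region and compares topk() with max() on a random tensor of the same shape as in the post. The helper name timed, the repeat count and the random input are assumptions for illustration, not part of the original code, and the actual numbers will depend on your GPU and PyTorch version.

```python
import time
import torch

def timed(fn, *args, repeats=10):
    """Return (last result, mean seconds per call) for fn(*args).

    Hypothetical helper: synchronizes the GPU before and after the timed
    loop so kernels queued earlier are not billed to fn.
    """
    use_cuda = torch.cuda.is_available()
    result = fn(*args)  # warm-up call (CUDA context init, kernel caching)
    if use_cuda:
        torch.cuda.synchronize()  # flush all work launched before timing starts
    start = time.perf_counter()
    for _ in range(repeats):
        result = fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait until the timed kernels have finished
    return result, (time.perf_counter() - start) / repeats

device = "cuda" if torch.cuda.is_available() else "cpu"
# Random stand-in for the model output: batch 2, 5 classes, 297x817 pixels.
output = torch.randn(2, 5, 297, 817, device=device)

(top_vals, pred), t_topk = timed(lambda x: x.topk(1, 1, True, True), output)
(max_vals, idx), t_max = timed(lambda x: x.max(1, keepdim=True), output)
print("topk: {:.4f}s  max: {:.4f}s".format(t_topk, t_max))
```

If the synchronized numbers for topk() and max() turn out to be close, the 3.6 s in the original measurement may mostly be the forward pass finishing rather than topk() itself.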

Isn't top-1 equivalent to max()?
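For k = 1 the two calls return the same values and indices along the class dimension. A small sketch with made-up scores (the shape and values are arbitrary):

```python
import torch

# Made-up scores: 2 samples (rows), 3 classes (columns).
scores = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.9]])

top_vals, top_idx = scores.topk(1, dim=1)            # shape (2, 1)
max_vals, max_idx = scores.max(dim=1, keepdim=True)  # shape (2, 1)

print(top_idx.squeeze(1).tolist())  # [1, 0]
print(max_idx.squeeze(1).tolist())  # [1, 0]
```

For a segmentation output of shape (N, C, H, W), output.argmax(1) gives the (N, H, W) label map directly.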

Yes, top-1 is equivalent to max(), and max() is very fast. But sometimes I want the top-3, and for that topk() is very slow: the CNN forward pass takes 0.27 s, while topk() takes nearly 3 s.
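When k > 1 but the class dimension is tiny (5 here), two alternatives worth benchmarking against topk() are sorting the whole class dimension and slicing, or taking k successive max() calls while masking each winner. Both helpers below (topk_by_sort, topk_by_max) are hypothetical names for this sketch; they assume a floating-point tensor without -inf entries, and whether either one is actually faster than topk() on a given GPU and PyTorch version has to be measured, with synchronization around the timed region.

```python
import torch

def topk_by_sort(x, k, dim=1):
    """Hypothetical helper: full descending sort along dim, keep the first k."""
    vals, idx = x.sort(dim=dim, descending=True)
    return vals.narrow(dim, 0, k), idx.narrow(dim, 0, k)

def topk_by_max(x, k, dim=1):
    """Hypothetical helper: k passes of max(), masking each winner with -inf.

    Assumes a floating-point tensor without -inf values; ties may be
    ordered differently than topk().
    """
    x = x.clone()  # keep the caller's tensor unchanged
    vals, idxs = [], []
    for _ in range(k):
        v, i = x.max(dim=dim, keepdim=True)
        vals.append(v)
        idxs.append(i)
        x.scatter_(dim, i, float("-inf"))  # exclude this winner from the next pass
    return torch.cat(vals, dim=dim), torch.cat(idxs, dim=dim)

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Small random stand-in for a segmentation output: N=2, C=5 classes, 8x8 pixels.
output = torch.randn(2, 5, 8, 8, device=device)

ref_vals, ref_idx = output.topk(3, 1, True, True)
s_vals, s_idx = topk_by_sort(output, 3)
m_vals, m_idx = topk_by_max(output, 3)
print(torch.equal(ref_vals, s_vals), torch.equal(ref_vals, m_vals))
```

The sort version does more work than strictly needed, but with only 5 entries per pixel along dim 1 that extra work is small; the max() loop instead costs k reductions plus one clone of the tensor.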

How did you solve this problem? Thanks.