I’d probably apply your topk call to the flattened tensor and then unravel the indices as described here:
topk
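Here is a minimal sketch of the flatten-then-unravel approach. It assumes PyTorch, and `topk_nd` and `unravel_index` are hypothetical helper names, not something from the linked post. `flatten()` enumerates elements in row-major order, so a flat index can be turned back into per-dimension coordinates by repeated modulo and floor division, starting from the last dimension.

```python
import torch

def unravel_index(flat_idx: torch.Tensor, shape):
    """Hypothetical helper: convert flat row-major indices into per-dimension index tensors."""
    coords = []
    for dim in reversed(shape):
        coords.append(flat_idx % dim)  # position within this dimension
        flat_idx = torch.div(flat_idx, dim, rounding_mode="floor")  # carry to the next-outer dim
    return tuple(reversed(coords))

def topk_nd(x: torch.Tensor, k: int):
    """Hypothetical helper: top-k values of a tensor of any shape, with their coordinates."""
    # topk on the flattened tensor considers every element at once
    values, flat_idx = torch.topk(x.flatten(), k)
    return values, unravel_index(flat_idx, x.shape)

x = torch.tensor([[1.0, 9.0, 3.0],
                  [7.0, 2.0, 8.0]])
values, (rows, cols) = topk_nd(x, 2)
print(values.tolist(), rows.tolist(), cols.tolist())  # [9.0, 8.0] [0, 1] [1, 2]
```

The returned coordinate tuple can be used directly for indexing, so `x[rows, cols]` gives back `values`. Recent PyTorch releases also ship `torch.unravel_index`, which could replace the hand-written helper.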
Best regards
Thomas