Batchwise argmax

What is the best way to do a batchwise argmax on a tensor? Let’s say I have a B * H * W tensor and I want to do an argmax over each H * W sample. I want to end up with a B * 2 tensor that holds the (row, column) argmax indices of each sample, stacked. Is a for loop with torch.stack the only way to do this?

Hi M!

Reshape your tensor of shape [B, H, W] into shape [B, H * W]
and call .argmax (dim = 1). Use the indices returned by argmax()
to set the per-sample max values to a sentinel value that does not
otherwise occur in your tensor. Reshape your tensor back to
[B, H, W] and use torch.where (t == sentinel). This returns a tuple
of three index tensors (batch, row, and column), each of length B.

You can stack the row and column tensors along dim = 1 into a tensor
of shape [B, 2] if you want a tensor for the final result.
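
Here is a minimal sketch of the steps above. The function names are
made up for illustration. As an assumption on my part, it marks the
argmax positions in a separate boolean mask rather than writing a
sentinel into the data, so you don't have to find a value that is
guaranteed not to occur in t. The second function is a swapped-in
alternative that skips torch.where and recovers (row, column) from the
flat index by arithmetic.

```python
import torch


def batch_argmax_2d(t: torch.Tensor) -> torch.Tensor:
    """Per-sample (row, col) argmax of a [B, H, W] tensor, returned as [B, 2]."""
    B, H, W = t.shape
    # Flatten each sample and take the argmax over its H * W entries.
    flat_idx = t.reshape(B, H * W).argmax(dim=1)            # shape [B]
    # Mark the argmax positions in a boolean mask (stands in for the sentinel).
    mask = torch.zeros(B, H * W, dtype=torch.bool, device=t.device)
    mask[torch.arange(B, device=t.device), flat_idx] = True
    # torch.where on a [B, H, W] mask returns three index tensors:
    # batch, row, and column, in batch order (one hit per sample).
    _, rows, cols = torch.where(mask.reshape(B, H, W))
    return torch.stack((rows, cols), dim=1)                 # shape [B, 2]


def batch_argmax_2d_arith(t: torch.Tensor) -> torch.Tensor:
    """Same result, converting the flat index to (row, col) arithmetically."""
    B, H, W = t.shape
    flat_idx = t.reshape(B, H * W).argmax(dim=1)
    rows = torch.div(flat_idx, W, rounding_mode="floor")
    cols = flat_idx % W
    return torch.stack((rows, cols), dim=1)


t = torch.tensor([[[1.0, 5.0], [2.0, 3.0]],
                  [[0.0, 1.0], [9.0, 4.0]]])
result = batch_argmax_2d(t)
result_arith = batch_argmax_2d_arith(t)
```

Both versions avoid a Python loop over the batch. If a sample has
several equal maxima, you get one (row, column) pair per sample,
whichever one argmax() picks.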


K. Frank