Hi,

What is the best way to do a batchwise argmax on a tensor? Let’s say I have a B * H * W tensor and I want to do an argmax on each H * W sample. I need to end up with a B * 2 tensor that holds the stacked argmax components (row and column) of each sample. Is a for loop with torch.stack the only possible way?

Hi M!

Reshape your tensor of shape `[B, H, W]` into shape `[B, H * W]` and call `.argmax (dim = 1)`. Use the indices returned by `argmax()` to set the per-sample max values to a `sentinel` value. Reshape your tensor back to `[B, H, W]` and use `torch.where (t == sentinel)` to get a tuple of three index tensors (batch, row, and column indices), each of length `B`, that hold the per-sample "argmax components." This assumes that `sentinel` does not otherwise occur in your tensor, so that exactly one element per sample matches.

You can stack the row and column index tensors into a tensor of shape `[B, 2]` if you want a tensor for the final result.
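A minimal sketch of the steps above. It assumes a floating-point input whose values are all finite, so that `float("inf")` can serve as the sentinel; the function name `batch_argmax_2d` is hypothetical.

```python
import torch


def batch_argmax_2d(t: torch.Tensor) -> torch.Tensor:
    """Return a [B, 2] tensor of (row, col) argmax positions for a [B, H, W] tensor.

    Assumption: t is floating point with finite values, so inf never occurs in it.
    """
    B, H, W = t.shape
    sentinel = float("inf")  # assumed never to appear in the input

    # Flatten each sample; clone so writing the sentinel leaves the caller's tensor untouched.
    flat = t.reshape(B, H * W).clone()
    idx = flat.argmax(dim=1)  # shape [B]: one flat index per sample

    # Mark exactly one maximum per sample with the sentinel.
    flat[torch.arange(B), idx] = sentinel

    # Reshape back and locate the sentinels; torch.where returns one index tensor per dim.
    marked = flat.reshape(B, H, W)
    _batch_idx, rows, cols = torch.where(marked == sentinel)

    # Indices come back in row-major order, so entry i belongs to sample i.
    return torch.stack((rows, cols), dim=1)  # shape [B, 2]


t = torch.zeros(2, 3, 4)
t[0, 1, 2] = 5.0
t[1, 2, 0] = 7.0
print(batch_argmax_2d(t).tolist())  # [[1, 2], [2, 0]]
```

Cloning before writing the sentinel keeps the original tensor unchanged. Because each sample contains exactly one sentinel, `torch.where` yields exactly `B` matches, one per sample, in batch order.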

Best.

K. Frank