Get indexes of maximum in batched images

I have batched images of shape (batch_size, channels, height, width). I would like to get the location of the maximum of each image in each channel, so an array of the form (batch_size, channels,2).
I thought of using torch.argmax, but it does not seem to work the way I want. Is there any simple way to do this?
Thanks in advance.

We make a fake image

import torch

# Fake batch of images: 3 samples, 2 channels, each 5x4, integer values in [-20, 20).
flat = torch.randint(-20, 20, (3 * 2 * 5 * 4,))
img = flat.reshape(3, 2, 5, 4)
print(img)
tensor([[[[ -5, -20,  18, -12],
          [-10,   2,  -9,  19],
          [  4,   3,  18,   6],
          [ -6, -14,  10,   2],
          [ 18,   5,  10,   8]],

         [[ 19, -16,  -1, -11],
          [ 18,  -6,  -3,   7],
          [  7, -15, -12, -17],
          [-17,  -8,   1,  19],
          [-13,  19,  16,  -1]]],


        [[[ -9, -13,   7,  19],
          [  5,  -3,  10,   0],
          [ -3,   4,  -8, -17],
          [  0,  10, -18, -12],
          [ 13,   3,  17,  17]],

         [[ -3,  -5,  12, -16],
          [  3,   6,  -5,  -4],
          [ -8,  -1,   9, -11],
          [ 17,  16,   6,  -7],
          [-10,  -4, -11, -18]]],


        [[[ -9, -12,  15, -13],
          [-12,  -4,  -2, -17],
          [ -6,  -7, -19, -16],
          [  8,   6, -12,  15],
          [ -8,  -4,  18,   7]],

         [[ -6,  -2,  -9,   5],
          [ 11,   5,  -5,  19],
          [ 18,  -6, -20, -16],
          [ 13, -11, -16,  12],
          [-13,  -2, -12,   2]]]])

We flatten the spatial dimensions while keeping the batch and channel dimensions. Then we call argmax over the flattened dimension and convert the flat index back into (row, column) coordinates.

# Flatten the 5x4 spatial grid into 20 entries per (batch, channel),
# take the argmax over that flat dimension, then unravel it.
flattened = img.view(3, 2, 20)
indices = torch.argmax(flattened, dim=-1)
# Flat index -> 2-D coordinates: row is the quotient, column the remainder.
row, column = indices // 4, indices % 4
print(row, column)
print(indices)
print(indices.shape)

tensor([[1, 0],
        [0, 3],
        [4, 1]]) tensor([[3, 0],
        [3, 0],
        [2, 3]])
tensor([[ 7,  0],
        [ 3, 12],
        [18,  7]])
torch.Size([3, 2])

Just check

print(img[2,1])
print(indices[2,1])
print(row[2,1])
print(column[2,1])
tensor([[ -6,  -2,  -9,   5],
        [ 11,   5,  -5,  19],
        [ 18,  -6, -20, -16],
        [ 13, -11, -16,  12],
        [-13,  -2, -12,   2]])
tensor(7)
tensor(1)
tensor(3)

@JuanFMontesinos many thanks for your answer! It is what I needed. Here is a version without hardcoded dimensions, and including the last step to have the location tensor with shape (batch_size,channels,2):


import torch

torch.manual_seed(0)

batch_size, channels, height, width = 4, 3, 6, 7

img = torch.randint(-20, 20, (batch_size, channels, height, width))

# Flatten the spatial (height, width) dims so argmax gives one flat
# index per (batch, channel) pair.
tmp = img.view(img.shape[0], img.shape[1], -1)
indices = torch.argmax(tmp, dim=-1)
# Unravel the flat index: floor division recovers the row, the
# remainder recovers the column (equivalent to indices - width*row
# for nonnegative indices, but clearer).
row = indices // width
column = indices % width
# torch.stack builds the (batch_size, channels, 2) location tensor
# directly, without the unsqueeze + cat dance.
locs = torch.stack([row, column], dim=-1)

print(locs)

print(img[2, 1])
print(row[2, 1])
print(column[2, 1])