Trouble using indexing over the mini-batch

I am having trouble expressing the following loop with indexing. How can I implement it without iterating over the mini-batch?

  • max_args has shape (batch_size, num_channels), with one column per chunk
  • each chunk has shape (batch_size, k, y, y), and max_args[j, i] is a channel index into chunk i
  • y == x.shape[2] == x.shape[3]
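import torch

out = []
for i, chunk in enumerate(chunks):
    # For chunk i, pick channel max_args[j, i] of every sample j in the batch
    temp = []
    for j in range(max_args.shape[0]):
        temp.append(chunk[j, max_args[j, i], :, :].view(1, 1, x.shape[2], x.shape[3]))
    out.append(torch.cat(temp, 0))  # (batch_size, 1, y, y)
out = torch.cat(out, 1)  # (batch_size, len(chunks), y, y)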
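One possible way to remove the batch loop is PyTorch advanced indexing with broadcast index tensors. The sketch below assumes all chunks have the same shape (batch_size, k, y, y), so they can be stacked, and that max_args is an integer tensor of shape (batch_size, len(chunks)). The function names `select_channels` and `select_channels_per_chunk` are hypothetical, chosen for illustration.

```python
import torch


def select_channels(chunks, max_args):
    """Hypothetical helper: fully vectorized version of the double loop.

    Assumes every chunk has the same shape (B, k, H, W) and max_args is (B, N),
    where N == len(chunks). Returns a tensor of shape (B, N, H, W).
    """
    stacked = torch.stack(list(chunks), dim=1)  # (B, N, k, H, W), allocates a copy
    B, N = max_args.shape
    batch_idx = torch.arange(B, device=max_args.device).unsqueeze(1)  # (B, 1)
    chunk_idx = torch.arange(N, device=max_args.device).unsqueeze(0)  # (1, N)
    # The three index tensors broadcast to (B, N); the untouched H, W dims are appended
    return stacked[batch_idx, chunk_idx, max_args.long()]


def select_channels_per_chunk(chunks, max_args):
    """Hypothetical helper: loops over chunks only, not over the mini-batch.

    Works even if chunks have different channel counts.
    """
    batch_idx = torch.arange(max_args.shape[0], device=max_args.device)
    # chunk[batch_idx, max_args[:, i]] pairs sample j with channel max_args[j, i] -> (B, H, W)
    picked = [chunk[batch_idx, max_args[:, i].long()] for i, chunk in enumerate(chunks)]
    return torch.stack(picked, dim=1)  # (B, N, H, W)
```

The first version needs `torch.stack`, which fails if the chunks differ in size (for example, when `torch.chunk` splits a channel count that is not evenly divisible); the second version avoids that at the cost of a short Python loop over the chunks.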