I am having trouble implementing the following expression with tensor indexing and would like to know how to compute it without iterating over the mini-batch.
- max_args has shape (batch_size, num_channels)
- each chunk has shape (batch_size, x, y, y)
- x.shape[2] == x.shape[3] == y
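
My current implementation loops over the chunks and over the batch:

    import torch

    out = []
    for i, chunk in enumerate(chunks):
        temp = []
        for j in range(max_args.shape[0]):  # loop over the mini-batch
            temp.append(chunk[j, max_args[j, i], :, :].view(1, 1, x.shape[2], x.shape[3]))
        out.append(torch.cat(temp, 0))  # (batch_size, 1, y, y)
    out = torch.cat(out, 1)  # (batch_size, num_channels, y, y)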
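One way to remove both loops is to stack the chunks into a single tensor and use advanced indexing with broadcast index tensors. This is a sketch that assumes `chunks` is a Python list of `num_channels` tensors of identical shape `(batch_size, x, y, y)` (for example, the output of `torch.chunk`) and that `max_args` is an integer tensor. The function names `select_channels` and `select_channels_loop` are hypothetical. The loop version mirrors the code above, but it uses `chunk.shape[2]` and `chunk.shape[3]` in place of `x.shape[2]` and `x.shape[3]`, which both equal `y` by assumption.

```python
import torch


def select_channels(chunks, max_args):
    """Vectorized version of the double loop: out[j, i] = chunks[i][j, max_args[j, i]].

    Assumes (hypothetically) that chunks is a list of num_channels tensors, each
    of shape (batch_size, x, y, y), and max_args is an integer tensor of shape
    (batch_size, num_channels). Returns a tensor of shape
    (batch_size, num_channels, y, y).
    """
    stacked = torch.stack(chunks, dim=1)  # (batch_size, num_channels, x, y, y)
    batch_size, num_channels = max_args.shape
    batch_idx = torch.arange(batch_size, device=max_args.device).unsqueeze(1)  # (batch_size, 1)
    chunk_idx = torch.arange(num_channels, device=max_args.device).unsqueeze(0)  # (1, num_channels)
    # The three index tensors broadcast to (batch_size, num_channels); the
    # trailing (y, y) dimensions are kept unchanged.
    return stacked[batch_idx, chunk_idx, max_args]


def select_channels_loop(chunks, max_args):
    """The original loop, kept for comparison (chunk.shape replaces x.shape)."""
    out = []
    for i, chunk in enumerate(chunks):
        temp = []
        for j in range(max_args.shape[0]):
            temp.append(chunk[j, max_args[j, i], :, :].view(1, 1, chunk.shape[2], chunk.shape[3]))
        out.append(torch.cat(temp, 0))
    return torch.cat(out, 1)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, num_channels, x_dim, y_dim = 4, 3, 5, 6
    chunks = [torch.randn(batch_size, x_dim, y_dim, y_dim) for _ in range(num_channels)]
    max_args = torch.randint(0, x_dim, (batch_size, num_channels))
    fast = select_channels(chunks, max_args)
    slow = select_channels_loop(chunks, max_args)
    print(fast.shape)  # torch.Size([4, 3, 6, 6])
    print(torch.equal(fast, slow))  # True
```

`torch.stack` copies the chunks into a new tensor, so this approach trades one extra allocation for removing the Python-level loops. Another option would be `torch.gather` along dimension 2 of the stacked tensor, with `max_args[:, :, None, None, None]` expanded to `(batch_size, num_channels, 1, y, y)` as the index, followed by `squeeze(2)`.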