I am having trouble expressing the loop below with tensor indexing. How can I implement it without iterating over the mini-batch? The shapes are:
- max_args has shape (batch_size, num_channels)
- each chunk has shape (batch_size, x, y, y)
- x.shape[2] == x.shape[3] == y
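```python
import torch

out = []
for i, chunk in enumerate(chunks):
    temp = []
    for j in range(max_args.shape[0]):  # loop over the mini-batch
        # channel max_args[j, i] of sample j: (y, y) -> (1, 1, y, y)
        temp.append(chunk[j, max_args[j, i], :, :].view(1, 1, chunk.shape[2], chunk.shape[3]))
    out.append(torch.cat(temp, 0))  # (batch_size, 1, y, y)
out = torch.cat(out, 1)  # (batch_size, len(chunks), y, y)
```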
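A sketch of one way to remove both loops, assuming every chunk has the same shape (as `torch.chunk` produces when the split dimension divides evenly), `len(chunks) == num_channels`, and `max_args` is an int64 tensor. The idea is to stack the chunks along a new dimension 1 and then select one `(y, y)` slice per (sample, chunk) pair with advanced indexing. A `torch.gather` variant is included as an alternative. The function names `select_channels` and `select_channels_gather` are hypothetical, chosen for this example.

```python
import torch


def select_channels(chunks, max_args):
    """Advanced-indexing version of the double loop (sketch).

    Assumes: chunks is a sequence of num_channels tensors, each (batch_size, x, y, y);
    max_args is an int64 tensor (batch_size, num_channels) with values in [0, x).
    Returns a tensor of shape (batch_size, num_channels, y, y).
    """
    stacked = torch.stack(list(chunks), dim=1)  # (batch_size, num_channels, x, y, y)
    batch_size, num_channels = max_args.shape
    batch_idx = torch.arange(batch_size, device=max_args.device).unsqueeze(1)  # (B, 1)
    chunk_idx = torch.arange(num_channels, device=max_args.device).unsqueeze(0)  # (1, N)
    # The three index tensors broadcast to (B, N); the trailing (y, y) dims are kept.
    return stacked[batch_idx, chunk_idx, max_args]


def select_channels_gather(chunks, max_args):
    """Same result via torch.gather along the x dimension (sketch, same assumptions)."""
    stacked = torch.stack(list(chunks), dim=1)  # (B, N, x, y, y)
    b, n, _, h, w = stacked.shape
    # Repeat each index over the spatial dims so the index has the same rank as stacked.
    idx = max_args.reshape(b, n, 1, 1, 1).expand(b, n, 1, h, w)
    return stacked.gather(2, idx).squeeze(2)  # (B, N, y, y)
```

Both functions return `(batch_size, num_channels, y, y)`, the same layout as `torch.cat(out, 1)` in the loop. `torch.stack` copies the chunks once; if they come from `torch.chunk(t, num_channels, dim=1)` on some tensor `t`, a `t.view(batch_size, num_channels, x, y, y)` may avoid that copy, but that depends on how the chunks were produced.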