Indexing using the indices returned by sort

I have two tensors of the same shape: (batch_size, nchannels, n).
I would like to sort one of the tensors, and then reorder the other tensor using the indices returned by sort.

Because index_select currently only supports a 1-D index vector, I use the following workaround to make this work:

import torch

a = torch.randn(2, 3, 4)
sorted_a, indices = a.sort(dim=1)

# Convert the per-column sort indices into offsets into the flattened tensor:
# flat = batch * (nchannels * n) + sorted_channel * n + position
new_indices = indices * a.size(-1)
batch_add_on = torch.arange(0, a.size(0))[:, None, None].repeat(1, a.size(1), a.size(2)).long() * a.size(1) * a.size(2)
last_dim_add_on = torch.arange(0, a.size(-1))[None, None, :].repeat(a.size(0), a.size(1), 1).long()
new_indices = new_indices + batch_add_on + last_dim_add_on

b = a  # stand-in for the second tensor, which has the same shape as a
sorted_b = torch.index_select(b.view(-1), dim=0, index=new_indices.view(-1)).view(a.size(0), a.size(1), a.size(2))

But when training on multiple GPUs, would this indexing method cause problems, if I understand data parallelism correctly?

Is there a better way to do this kind of indexing, or is data parallelism smart enough to handle it?


Did you find a better solution?
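
One possible alternative, offered as a sketch rather than a definitive answer: torch.gather accepts an index tensor of the same shape as its input, so the indices returned by sort can be applied to the second tensor directly, without building flat offsets. The shapes and the random tensor b below are assumptions chosen only to mirror the example above. Because gather works on whatever tensor it is handed and needs no auxiliary arange tensors, it seems likely to behave the same on each replica's chunk under data parallelism, though that has not been verified here.

```python
import torch

torch.manual_seed(0)

# Illustrative shapes: (batch_size, nchannels, n)
a = torch.randn(2, 3, 4)
b = torch.randn(2, 3, 4)  # second tensor with the same shape as a

# Sort a along dim=1 and keep the permutation that was applied.
sorted_a, indices = a.sort(dim=1)

# Apply the same permutation to b along dim=1:
# sorted_b[i, j, k] = b[i, indices[i, j, k], k]
sorted_b = torch.gather(b, 1, indices)

# Sanity check: gathering a with its own sort indices reproduces sorted_a.
assert torch.equal(torch.gather(a, 1, indices), sorted_a)
```

The result should match what the flat index_select workaround produces, as long as the dim passed to gather is the same one passed to sort.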