Efficient batch index_select

Hi all,

I have tried two ways to do a batched index_select, but both still have problems. Here are the weight tensor and the index:

W = torch.rand(40000, 1024)
index = torch.randint(0, 40000, size=(256, 2000))

Here, 256 is the batch size.

I can achieve a batched index_select in either of the following two ways:

First way: W[index, :]

Second way: torch.gather(W.expand(256, -1, -1), 1, index.unsqueeze(2).expand(-1, -1, 1024))
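
A scaled-down sketch of the two approaches above, assuming smaller hypothetical sizes (B, V, L, D) so it runs quickly. It checks that advanced indexing and gather on the expanded weight produce the same tensor.

```python
import torch

# Hypothetical small sizes standing in for 256 x 2000 lookups into a 40000 x 1024 table.
B, V, L, D = 4, 100, 7, 16  # batch, table rows, indices per sample, feature dim

W = torch.rand(V, D)
index = torch.randint(0, V, size=(B, L))

# First way: advanced indexing with a 2-D index tensor -> shape (B, L, D).
out_index = W[index, :]

# Second way: gather along dim 1 of a (B, V, D) expanded view of W.
# expand() adds the batch dim without copying; the index is expanded to D columns.
out_gather = torch.gather(W.expand(B, -1, -1), 1, index.unsqueeze(2).expand(-1, -1, D))

print(out_index.shape)  # torch.Size([4, 7, 16])
assert torch.equal(out_index, out_gather)
```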

Both result tensors are identical, with size [256, 2000, 1024].

The first way uses a lot of memory because the result tensor does not share memory with the source tensor. The second way, using gather, takes a lot of time.
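
For scale, a back-of-the-envelope calculation, assuming W stays float32 (the default dtype of torch.rand): any method that materializes the [256, 2000, 1024] result needs roughly 2 GB for the output tensor alone.

```python
# Size of the materialized result for the shapes in the question, assuming float32 (4 bytes).
batch, n_idx, dim = 256, 2000, 1024
bytes_per_elem = 4  # torch.rand defaults to float32

n_elems = batch * n_idx * dim          # 524,288,000 elements
n_bytes = n_elems * bytes_per_elem     # 2,097,152,000 bytes
print(n_bytes / 2**30)                 # 1.953125 (GiB)
```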

Does anyone have a more efficient way to do a batched index_select?

Regardless of the method, you cannot have the resulting Tensor share the same memory as the source Tensor.
The reason is that you aren’t asking for a slice of the original Tensor; you’re asking for particular indices of it. Only slices can be expressed as views on the memory of the original Tensor.
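
A small illustration of the view-versus-copy distinction, using a hypothetical toy tensor: a write through a basic slice shows up in the source, while a write into the result of integer-tensor indexing does not.

```python
import torch

W = torch.arange(12.0).view(4, 3)

# Basic slicing returns a view: it shares storage with W.
rows_slice = W[1:3]
rows_slice[0, 0] = -1.0
print(W[1, 0])  # tensor(-1.) -> the write is visible through W

# Advanced (integer-tensor) indexing returns a copy in new storage.
rows_idx = W[torch.tensor([1, 1, 2])]
rows_idx[0, 0] = 100.0
print(W[1, 0])  # still tensor(-1.) -> W is unaffected
```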

Your first way is the efficient way of doing this.

Alternatively, do see if this is faster (probably not, but worth a try):

W[index.view(-1), :].view(index.size(0), index.size(1), 1024)
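
A sketch of the flattened-index variant with hypothetical small sizes, plus the equivalent torch.index_select spelling (which also takes a 1-D index). Both still materialize the full (batch, n_idx, D) tensor, so they produce exactly what W[index, :] produces.

```python
import torch

# Hypothetical small sizes; the trick only changes how the index is shaped.
V, D = 50, 8
W = torch.rand(V, D)
index = torch.randint(0, V, size=(3, 5))

# Look up with a flat 1-D index, then restore the (batch, n_idx, D) shape.
flat = W[index.view(-1), :].view(index.size(0), index.size(1), D)

# Equivalent spelling with index_select, which also expects a 1-D index.
flat_sel = torch.index_select(W, 0, index.view(-1)).view(*index.shape, D)

assert torch.equal(flat, W[index, :])
assert torch.equal(flat_sel, W[index, :])
```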

Thanks, I have also tried this, but the memory cost is the same as W[index, :].

Actually, W is a parameter tensor that requires gradients. W[index, :] might contain two rows taken from the same source row. Will these two rows be updated independently during backpropagation?
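
Duplicated rows are not updated independently: during backward, the gradients flowing into every copy of a row are summed (accumulated) into that row of W.grad. A small check with hypothetical toy values:

```python
import torch

W = torch.nn.Parameter(torch.zeros(4, 3))
index = torch.tensor([[0, 2, 2],
                      [2, 1, 0]])  # row 2 picked three times, row 0 twice, row 1 once, row 3 never

out = W[index, :]        # shape (2, 3, 3)
out.sum().backward()     # d(sum)/d(out) is all ones

# Gradients from every occurrence of a row are summed into that row of W.grad.
print(W.grad[:, 0])      # tensor([2., 1., 3., 0.])
```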

Hi @smth, will backprop work correctly through the output of torch.gather? I became suspicious because I found that torch.gather on an expanded tensor (i.e., the output of Tensor.expand) results in a larger storage area than the original tensor (due to the expansion, I am guessing).

Tensor Y: Shape = (12, 15, 512), Storage = 368704 bytes
Tensor indices: Shape = (12, 15, 5, 512), Storage = 664 bytes
Y_expanded = Y.unsqueeze(2).expand(-1, -1, 15, -1)
Tensor Y_expanded: Shape = (12, 15, 15, 512), Storage = 368704 bytes

Up to this point, the storage size of Y_expanded equals that of Y.
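
This is expected: expand() returns a view that repeats data by using stride 0 along the expanded dimension, so no new memory is allocated. A small sketch with hypothetical shapes standing in for (12, 15, 512):

```python
import torch

# Hypothetical small shapes standing in for (12, 15, 512).
Y = torch.rand(2, 3, 4)
Y_expanded = Y.unsqueeze(2).expand(-1, -1, 3, -1)  # (2, 3, 3, 4)

print(Y_expanded.shape)     # torch.Size([2, 3, 3, 4])
print(Y_expanded.stride())  # (12, 4, 0, 1): stride 0 on the new dim, nothing copied
assert Y_expanded.data_ptr() == Y.data_ptr()
```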

Y_selected = torch.gather(Y_expanded, dim=2, index=m_indices)
Tensor Y_selected: Shape = (12, 15, 5, 512), Storage = 1843264 bytes

However, after I select from Y_expanded, the storage size increases, implying a copy into new storage. Is there a way I can avoid the increase in size? Also, will backprop flow back through Y if I use this method?
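
On the backprop question: gradients do flow through torch.gather and through the expand view back to Y. gather writes its result into a newly allocated tensor with the index's shape, which is why the storage grows. The sketch below assumes the goal is to pick K rows of Y[b] for every position i; note that it expands along dim 1 so that Y_exp[b, i, j] == Y[b, j], and the names B, N, K, D, nbr are hypothetical. It compares values and gradients against plain advanced indexing.

```python
import torch

torch.manual_seed(0)
B, N, K, D = 2, 5, 3, 4  # hypothetical stand-ins for (12, 15, 5, 512)

Y = torch.rand(B, N, D, requires_grad=True)
nbr = torch.randint(0, N, (B, N, K))  # for each (b, i): K row indices into Y[b]

# Gather along dim 2 of a stride-0 expanded view; the output is a new (B, N, K, D) tensor.
Y_exp = Y.unsqueeze(1).expand(-1, N, -1, -1)   # Y_exp[b, i, j] == Y[b, j]
idx = nbr.unsqueeze(-1).expand(-1, -1, -1, D)  # index must have the output's shape
out = torch.gather(Y_exp, 2, idx)

# Reference: plain advanced indexing with a broadcast batch index.
Y_ref = Y.detach().clone().requires_grad_(True)
out_ref = Y_ref[torch.arange(B)[:, None, None], nbr]  # also (B, N, K, D)

weights = torch.rand(B, N, K, D)
(out * weights).sum().backward()
(out_ref * weights).sum().backward()

print(out.shape)  # torch.Size([2, 5, 3, 4])
assert torch.allclose(out, out_ref)
assert torch.allclose(Y.grad, Y_ref.grad)  # gradients reach Y through gather + expand
```

Advanced indexing gives the same values and gradients here without the expand step.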

There are lots of people asking for this… any hope of getting it?

We are all waiting for a savior…