Let's say I have a tensor A of shape [batch_size x length x 1024].
I want to do the following:
for the i-th element in the batch, I want to shift the embedding at each position along the length dimension by that position's index.
For example, the vector A[0, 0, :] should stay the same, A[0, 1, :] should be shifted (or rolled) by 1, and A[0, 15, :] should be shifted by 15.
This applies to all elements in the batch.
So far I did it with for loops, but it's not efficient.
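To make the intended behaviour concrete, here is a tiny illustration on a toy tensor (batch_size=1, length=3, embedding size 4; the sizes are made up for the example). It assumes "shift" means torch.roll along the embedding dimension, which moves each value k slots to the right and wraps the overflow around to the front.

```python
import torch

# Toy input: batch_size=1, length=3, embedding=4
# A[0] = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
A = torch.arange(12).view(1, 3, 4)

# Roll the embedding at position t by t (position 0 is unchanged)
desired = torch.stack([torch.roll(A[0, t], t, 0) for t in range(3)]).unsqueeze(0)

# Row 1 [4, 5, 6, 7] rolled by 1 becomes [7, 4, 5, 6]
# Row 2 [8, 9, 10, 11] rolled by 2 becomes [10, 11, 8, 9]
print(desired)
```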
Could you post your current approach with loops, as I’m not completely understanding the use case.
If you shift A[0, 1] by 1 in dim1, would A[0, 2] be overwritten or also shifted?
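```python
import torch

A = ...  # A is of shape [batch_size, length, embedding]
batch_size, length, _ = A.shape
target_embeddings = []
for i in range(batch_size):
    for index in range(length):
        tmp = A[i, index, :]
        # roll the embedding vector at this position by its position index
        target_embeddings.append(torch.roll(tmp, index, 0))
# stack the rolled vectors and restore the [batch_size, length, embedding] shape
result = torch.stack(target_embeddings).view(batch_size, length, -1)
```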
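One possible way to drop the loops is to precompute the rolled indices and use torch.gather. Since torch.roll(x, k, 0)[j] equals x[(j - k) % D] for a vector of size D, the whole operation is out[b, t, j] = A[b, t, (j - t) % D]. The sketch below builds that index table once and gathers along the last dimension; the function name roll_by_position is hypothetical, and the code assumes a 3D input of shape [batch_size, length, D].

```python
import torch


def roll_by_position(A: torch.Tensor) -> torch.Tensor:
    """Roll A[b, t, :] by t along the last dimension, for all b and t at once."""
    batch_size, length, dim = A.shape
    pos = torch.arange(length, device=A.device).unsqueeze(1)   # [length, 1]
    cols = torch.arange(dim, device=A.device).unsqueeze(0)     # [1, dim]
    # Tensor % uses floor-mod (like torch.remainder), so negative values wrap to [0, dim)
    idx = (cols - pos) % dim                                   # [length, dim]
    # Same index table for every batch element
    idx = idx.unsqueeze(0).expand(batch_size, length, dim)
    return torch.gather(A, 2, idx)


A = torch.randn(4, 16, 1024)
out = roll_by_position(A)
print(out.shape)  # torch.Size([4, 16, 1024])
```

Because the index table depends only on length and D, it could also be cached and reused across calls if those sizes are fixed.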