I have a list of tensors of different shapes. Each tensor in this list is used to select a subset of rows of a 2D tensor W. Consider the following example:
l1 = [torch.tensor([0, 1]),
torch.tensor([0, 1]),
torch.tensor([0, 2, 3, 5, 6]),
torch.tensor([0, 2, 3, 5, 6]),
torch.tensor([0, 2, 3, 5]),
torch.tensor([0, 2, 3]),
torch.tensor([0, 2, 4]),
torch.tensor([0, 2, 4])]
l1
>>> [tensor([0, 1]),
tensor([0, 1]),
tensor([0, 2, 3, 5, 6]),
tensor([0, 2, 3, 5, 6]),
tensor([0, 2, 3, 5]),
tensor([0, 2, 3]),
tensor([0, 2, 4]),
tensor([0, 2, 4])]
I have a global 2D weight tensor W of shape [7, 300]. For each tensor in the list above, I want to select the rows corresponding to the indices in that tensor. So I'd use the first tensor, tensor([0, 1]), to select the 1st and 2nd rows of W. Similarly, I'd use the third tensor (tensor([0, 2, 3, 5, 6])) to select the 1st, 3rd, 4th, 6th and 7th rows of the weight matrix W. I do this using a for loop as follows.
W = torch.randn((7, 300), dtype=torch.float)
W
>>> tensor([[-0.3830, -0.5174, -1.8962, ..., 0.9110, 1.4701, 0.4347],
[ 0.2234, -0.0493, -1.1374, ..., -0.5010, 0.0462, 0.6938],
[-0.1562, -0.1772, -0.9825, ..., -0.4177, -0.2688, 0.1964],
...,
[ 1.9325, -0.5216, 1.0538, ..., -1.1179, -0.1078, 1.0612],
[-0.9618, 0.6065, -1.7601, ..., 0.1251, -0.7518, 0.6883],
[-1.0139, 0.2360, 1.4346, ..., 0.0604, 0.4170, -1.3022]])
W.shape
>>> torch.Size([7, 300])
selected_rows = []
for idx in l1:
    selected_rows.append(W[idx])
[row.shape for row in selected_rows]
>>> [torch.Size([2, 300]),
torch.Size([2, 300]),
torch.Size([5, 300]),
torch.Size([5, 300]),
torch.Size([4, 300]),
torch.Size([3, 300]),
torch.Size([3, 300]),
torch.Size([3, 300])]
As we can see, we get a list of tensors that are the selected rows of the matrix W, according to the indices provided in the list l1. However, I want to avoid the for loop here. We could use torch.stack to convert the list l1 into a single tensor and then use that tensor to index the weight matrix W. However, I cannot use it, since the shapes of the tensors in the list are different. I don't know if there is another way to do this operation in PyTorch without using a for loop. If there is, please let me know. Thanks in advance.
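To make the failure concrete, here is a minimal repro of the torch.stack attempt (the exact error message may vary across PyTorch versions):

```python
import torch

l1 = [torch.tensor([0, 1]), torch.tensor([0, 2, 3, 5, 6])]

# torch.stack requires every tensor to have exactly the same shape,
# so stacking ragged index tensors raises a RuntimeError.
try:
    torch.stack(l1)
except RuntimeError as e:
    print("torch.stack failed:", e)
```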
If all the tensors in the list had the same shape, like the following, we could easily avoid the for loop:
l1 = torch.tensor([[0, 1], [0, 1]])
l1
>>> tensor([[0, 1],
[0, 1]])
selected_rows = W[l1]
selected_rows.shape
>>> torch.Size([2, 2, 300])
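For what it's worth, one workaround I've considered (I'm not sure it is the idiomatic one) is to concatenate all the index tensors, index W once, and then split the result back by the original lengths. Computing the split sizes still uses a Python comprehension, but the indexing itself becomes a single vectorized operation:

```python
import torch

torch.manual_seed(0)
W = torch.randn((7, 300), dtype=torch.float)
l1 = [torch.tensor([0, 1]),
      torch.tensor([0, 1]),
      torch.tensor([0, 2, 3, 5, 6]),
      torch.tensor([0, 2, 3, 5, 6]),
      torch.tensor([0, 2, 3, 5]),
      torch.tensor([0, 2, 3]),
      torch.tensor([0, 2, 4]),
      torch.tensor([0, 2, 4])]

# One fused gather over all indices, then split back into per-tensor chunks.
sizes = [t.numel() for t in l1]            # [2, 2, 5, 5, 4, 3, 3, 3]
selected_rows = W[torch.cat(l1)].split(sizes)

print([r.shape for r in selected_rows])
```

This gives the same list of shapes as the for-loop version; whether it is actually faster likely depends on the number and size of the index tensors.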