Selecting a set of rows from a 2D tensor

I have a large tensor of size [80000, 300]. I want to select a set of n rows from this 2D tensor, so the resulting tensor should have shape [n, 300]. I looked at this similar topic, and it proposed a very simple solution using torch.ByteTensor. However, my problem is that I cannot initialise a ByteTensor that large.

I tried to create a ByteTensor for indexing the main 2D tensor as follows. I first create a zero tensor of size 80000, then I fill the positions of the required rows with 1's. For example, if I want to select the 1st, 5th, and 7th rows, I'd do the following.

ind = torch.zeros(80000)
ind[0] = 1
ind[4] = 1
ind[6] = 1

Then how do I obtain the ByteTensor from this tensor? Simply calling torch.ByteTensor(ind) raises an error.

You could convert the tensor via ind = ind.byte(). However, if the number of indices is small, it should be beneficial to index the tensor directly with the row indices:

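import torch

x = torch.randn(80, 3)            # small stand-in for the [80000, 300] tensor
idx = torch.tensor([0, 4, 6])     # 1st, 5th, and 7th rows
ret = x[idx]                      # shape [3, 3]: one row per index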
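For comparison, here is a minimal sketch of both approaches on a tensor of the size from the question. It assumes a reasonably recent PyTorch release, where masks are normally built with dtype=torch.bool rather than as byte tensors. The names rows, mask, by_mask, by_index, and by_select are chosen for illustration only.

```python
import torch

x = torch.randn(80000, 300)       # stand-in for the large tensor
rows = [0, 4, 6]                  # 1st, 5th, and 7th rows
idx = torch.tensor(rows)

# Option 1: boolean mask with one entry per row of x
mask = torch.zeros(x.size(0), dtype=torch.bool)
mask[idx] = True
by_mask = x[mask]

# Option 2: index directly with the integer indices
by_index = x[idx]

# Option 3: the same selection via torch.index_select along dim 0
by_select = torch.index_select(x, 0, idx)

print(by_mask.shape)              # torch.Size([3, 300])
```

Note that a mask always returns the selected rows in ascending order and cannot repeat a row, whereas integer indexing returns rows in the order given in idx and allows duplicates.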

Thank you so much @ptrblck