How to select specific elements of a tensor?

I have a tensor

# 5x10 example tensor: row r holds the values 10*r .. 10*r + 9.
y = torch.arange(50).reshape(5, 10)
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

Now I want to select the elements of y whose column indices appear in s, where each row of s corresponds to the row of y I want to select the elements from. For example, if s is

# 5x2 column indices: row i lists the columns of y[i] to select.
s = torch.arange(10).reshape(5, 2)
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])

The output after slicing (y[s]) should be:

[[0,1],
[12,13],
[24,25],
[36,37],
[48,49]]
  • How can I do it?
  • I tried y[s], but it does not work. (I know that every row of s must have the same number of columns.)
  • Will the gradient be preserved after slicing?

My naive approach would be this:

# Loop version: for each row i, keep y[i] at the column positions s[i].
# Index y[i, s[i]] directly; the earlier form y[:, s][:, i, :][i, :] built the
# whole (rows, rows, k) tensor y[:, s] on every iteration only to keep one
# slice of it. The loop-free equivalent is torch.gather(y, 1, s).
stack = [y[i, s[i]] for i in range(s.shape[0])]
output_selec = torch.stack(stack)
output_selec
tensor([[ 0,  1],
        [12, 13],
        [24, 25],
        [36, 37],
        [48, 49]])

But 1) I am sure there is a more Pythonic/cleaner way (maybe @ptrblck can suggest one), and 2) I am not sure whether it disconnects the gradient from previous operations.

As far as I remember, slicing tensors does cause the tensor to lose its grad_fn.

I see. Any suggestion on how to slice and keep the grad_fn?

Does torch.gather work for you?

import torch
# Example data: a 5x10 tensor and, per row, the two column indices to keep.
y = torch.arange(50).view(5, 10)
s = torch.arange(10).view(5, 2)
# Gather along dim 1: result[i][j] == y[i][s[i][j]].
print(y.gather(dim=1, index=s))
tensor([[ 0,  1],
        [12, 13],
        [24, 25],
        [36, 37],
        [48, 49]])
3 Likes

Apparently, it does, thanks.
@eqy is torch.gather safe for keeping the gradient?
I see in the torch.gather documentation that it has a parameter called sparse_grad; I did not quite understand what it means for the gradient to be a sparse tensor.

We can directly test if backprop works with both slice and gather:

import torch

class SliceNet(torch.nn.Module):
    """Toy model that keeps only the first 32 of its 64 hidden units by slicing.

    Used to check that gradients flow back through ``x[:, :32]``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(32, 64)
        self.sigmoid = torch.nn.Sigmoid()
        self.linear2 = torch.nn.Linear(32, 1, bias=True)

    def forward(self, x):
        hidden = self.sigmoid(self.linear1(x))
        # Keep hidden units 0..31; the slice stays in the autograd graph.
        kept = hidden[:, :32]
        return self.sigmoid(self.linear2(kept))

class GatherNet(torch.nn.Module):
    """Toy model that keeps every other hidden unit (0, 2, ..., 62) via ``torch.gather``.

    Used to check that gradients flow back through ``torch.gather``. Works for
    any batch size: the stored index row is broadcast to the input's batch
    dimension in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(32, 64)
        self.sigmoid = torch.nn.Sigmoid()
        self.linear2 = torch.nn.Linear(32, 1, bias=True)
        batch_size = 2
        # Column indices 0, 2, ..., 62 (32 of them, matching linear2's input),
        # one identical row per sample of the original fixed batch of 2.
        # Registered as a non-persistent buffer so .to(device) / .cuda() move
        # it together with the parameters (torch.gather expects index on the
        # same device as its input) without adding a key to state_dict().
        self.register_buffer(
            "indices",
            torch.arange(0, 64, 2).repeat(batch_size, 1),
            persistent=False,
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.sigmoid(x)
        # All index rows are identical, so broadcast the first one to the
        # actual batch size. Using self.indices as-is would silently return
        # only 2 rows for larger batches and raise for a batch of 1.
        index = self.indices[:1].expand(x.shape[0], -1)
        x = torch.gather(x, 1, index)
        x = self.linear2(x)
        return self.sigmoid(x)

def test_gather_gradient(steps=10000, lr=1e-4, log_every=1000, seed=None):
    """Train GatherNet on one fixed random batch and report the loss.

    Prints the MSE loss every ``log_every`` steps; it should go down if
    gradients flow through ``torch.gather``. The logged values are also
    returned, so the trend can be checked programmatically, e.g.
    ``losses[-1] < losses[0]``.

    Args:
        steps: number of SGD steps.
        lr: SGD learning rate.
        log_every: print/record the loss every this many steps.
        seed: if not None, passed to ``torch.manual_seed`` first so the run
            is reproducible.

    Returns:
        List of the logged loss values (floats), in step order.
    """
    print("testing gathernet, loss should go down")
    if seed is not None:
        torch.manual_seed(seed)
    data = torch.randn(2, 32)
    # NOTE: rounding randn can produce targets such as -1 or 2, outside the
    # final sigmoid's range, so the absolute loss level depends on the draw
    # and is not comparable across runs; only the downward trend matters.
    label = torch.randn(2, 1).round()
    gathernet = GatherNet()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(gathernet.parameters(), lr)
    losses = []
    for step in range(steps):
        out = gathernet(data)
        loss = criterion(out, label)
        if step % log_every == 0:
            losses.append(loss.item())
            print(losses[-1])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return losses

def test_slice_gradient(steps=10000, lr=1e-4, log_every=1000, seed=None):
    """Train SliceNet on one fixed random batch and report the loss.

    Prints the MSE loss every ``log_every`` steps; it should go down if
    gradients flow through the slice ``x[:, :32]``. The logged values are
    also returned, so the trend can be checked programmatically, e.g.
    ``losses[-1] < losses[0]``.

    Args:
        steps: number of SGD steps.
        lr: SGD learning rate.
        log_every: print/record the loss every this many steps.
        seed: if not None, passed to ``torch.manual_seed`` first so the run
            is reproducible.

    Returns:
        List of the logged loss values (floats), in step order.
    """
    print("testing slicenet, loss should go down")
    if seed is not None:
        torch.manual_seed(seed)
    data = torch.randn(2, 32)
    # NOTE: rounding randn can produce targets such as -1 or 2, outside the
    # final sigmoid's range, so the absolute loss level depends on the draw
    # and is not comparable across runs; only the downward trend matters.
    label = torch.randn(2, 1).round()
    slicenet = SliceNet()
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(slicenet.parameters(), lr)
    losses = []
    for step in range(steps):
        out = slicenet(data)
        loss = criterion(out, label)
        if step % log_every == 0:
            losses.append(loss.item())
            print(losses[-1])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return losses

# Run both training demos: gather first, then slice.
for demo in (test_gather_gradient, test_slice_gradient):
    demo()
testing gathernet, loss should go down
0.16012126207351685
0.13146084547042847
0.10953329503536224
0.09261573851108551
0.07940082252025604
0.06893321871757507
0.06052353233098984
0.053674302995204926
0.04802417382597923
0.043307825922966
testing slicenet, loss should go down
1.1417878866195679
0.950451672077179
0.828807532787323
0.7515791654586792
0.7005079984664917
0.6650912761688232
0.6394497156143188
0.6201967597007751
0.60529625415802
0.593470573425293

In general, I don’t think most indexing operations should be problematic for backprop. For example, could max pooling work if indexing were a problem?

1 Like

Minor correction: slicing doesn’t detach the output:

# Slice a leaf tensor and backprop through the slice: only the 3x2 sliced
# entries receive gradient (1/6 each, from the mean over 6 elements).
x = torch.randn(4, 4, requires_grad=True)
y = x[0:3, 0:2]
torch.mean(y).backward()
print(x.grad)
> tensor([[0.1667, 0.1667, 0.0000, 0.0000],
          [0.1667, 0.1667, 0.0000, 0.0000],
          [0.1667, 0.1667, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000]])
1 Like