How to detach() rows of a tensor?

Is it possible to detach a row or a few elements of a tensor? I tried this:

>>> a = torch.rand((5,3),requires_grad = True)
>>> a[2] = a[2].detach()
>>> a[2].requires_grad
True
>>> a.requires_grad
True

I don’t think it’s possible. Could you just split the tensor and use the parts:

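import torch

a = torch.randn(5, 3, requires_grad=True)
b = a[2].detach()    # row 2, cut off from the autograd graph
a = a[[0, 1, 3, 4]]  # remaining rows, still tracked by autograd
print(b.requires_grad)  # False
print(a.requires_grad)  # True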

or could you explain your use case a bit more?
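If the goal is to keep working with a single (5, 3) tensor, the split parts can be concatenated back together so that only the chosen row is cut off from the graph. This is a minimal sketch, assuming the row to detach is known in advance (row 2, as in the question); the name `mixed` is just illustrative:

```python
import torch

a = torch.randn(5, 3, requires_grad=True)

# Rebuild a full (5, 3) tensor in which only row 2 comes from a detached slice
mixed = torch.cat([a[:2], a[2:3].detach(), a[3:]], dim=0)

mixed.sum().backward()
print(a.grad)  # row 2 is all zeros, every other row is all ones
```

Slicing with `a[2:3]` instead of `a[2]` keeps the row two-dimensional, so `torch.cat` can join it with the other slices along dim 0.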

I found that even though a[2].requires_grad is True, the gradient doesn't flow through a[2]. Maybe this is a bug? More generally, if you want to detach() an arbitrary part of a tensor, you can repeat the tensor into two copies, apply detach() to the second copy, and use torch.gather on the combined tensor to pick each element from the appropriate copy.
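The two-copies-plus-gather idea can be sketched roughly as follows. This is an illustration under assumptions, not the poster's code: `torch.stack` is used to hold the two copies, `mask` (1 means "take the detached copy") and `mixed` are hypothetical names, and rows 2 and 4 are detached arbitrarily:

```python
import torch

a = torch.ones(5, 3, requires_grad=True)

# Two copies along a new leading dim: index 0 is tracked by autograd, index 1 is detached
both = torch.stack([a, a.detach()], dim=0)  # shape (2, 5, 3)

# Per-element choice of copy: 0 = keep the gradient, 1 = use the detached copy
mask = torch.zeros(5, 3, dtype=torch.long)
mask[[2, 4]] = 1  # detach rows 2 and 4

# gather picks both[mask[i, j], i, j]; the index must have shape (1, 5, 3)
mixed = torch.gather(both, 0, mask.unsqueeze(0)).squeeze(0)  # shape (5, 3)

loss = torch.sum((2 * mixed) ** 2)
loss.backward()
print(a.grad)  # rows 2 and 4 are zero, the other rows are 8
```

Because the mask works per element, it can detach any pattern of entries, not just whole rows. For a simple two-way choice like this, `torch.where(mask.bool(), a.detach(), a)` should give the same result with less bookkeeping.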


Strictly speaking this is not a “detach”, but this way you can set the gradients of those rows to zero with a hook:

import torch

def modify_grad(x, inds):
    # Hooks should not modify their argument in place; return a modified copy instead
    x = x.clone()
    x[inds] = 0
    return x

a = torch.ones((5, 3)).requires_grad_(True)

inds = [2, 4]  # rows which will have zero gradients
b = 2 * a
b.register_hook(lambda x: modify_grad(x, inds))

loss = torch.sum(b**2)
loss.backward()

print('a.grad', a.grad)

will print

a.grad tensor([[8., 8., 8.],
        [8., 8., 8.],
        [0., 0., 0.],
        [8., 8., 8.],
        [0., 0., 0.]])
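One caveat: zeroing rows of b's gradient zeros the same rows of a.grad here only because b = 2 * a is elementwise. If a feeds into an operation that mixes rows, the hook can be registered on the leaf a itself instead. A minimal sketch, assuming only a.grad needs to be masked; `zero_rows` and the all-ones mixing matrix `m` are hypothetical, chosen so the expected numbers are easy to check:

```python
import torch

def zero_rows(grad, inds):
    # Return a modified copy instead of editing the hook's argument in place
    grad = grad.clone()
    grad[inds] = 0
    return grad

a = torch.ones(5, 3, requires_grad=True)
inds = [2, 4]  # rows which will have zero gradients
a.register_hook(lambda g: zero_rows(g, inds))

# Hypothetical downstream op that mixes rows: every output row depends on every row of a
m = torch.ones(5, 5)
loss = (m @ a).sum()
loss.backward()
print(a.grad)  # rows 2 and 4 are zero, the other rows are 5
```

With the hook on b instead, the row mixing in `m @ a` would still send gradient into rows 2 and 4 of a.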

Thanks. This is very helpful!