Slicing a Tensor in PyTorch

Suppose I have a tensor:

a = torch.randn(B,N,V)

I want to get the third column of the tensor a along axis V, keeping that axis as a size-1 dimension so the result has shape (B,N,1).

I could do this by:

a_slice = a[:,:,2,None]  # indexing is zero-based, so the third column is index 2
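For reference, here is a minimal sketch showing that this slice, and a few equivalent formulations built from plain torch functions, all produce the same (B,N,1) result. It takes the third column (index 2, since PyTorch indexing is zero-based); the sizes B, N, V = 2, 3, 5 are arbitrary values chosen for illustration.

```python
import torch

B, N, V = 2, 3, 5  # arbitrary example sizes (assumption, for illustration only)
a = torch.randn(B, N, V)

# Equivalent ways to take column index 2 along the last axis (V)
# while keeping that axis as size 1:
s1 = a[:, :, 2, None]              # integer index, then None inserts a new axis
s2 = a[:, :, 2:3]                  # a length-1 slice keeps the axis directly
s3 = a.select(2, 2).unsqueeze(-1)  # select drops dim 2, unsqueeze restores it at the end
s4 = torch.narrow(a, 2, 2, 1)      # narrow(input, dim, start, length)

print(s1.shape)  # torch.Size([2, 3, 1])
```

The slice-based forms (s2, s4) keep the axis without needing None or unsqueeze, which some people find easier to read.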

In particular, I'm worried that this indexing step may not be differentiable. Is that the case?

If so, is there a way of doing this with Torch functions so that I don’t run into issues?

Hi,

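All indexing operations are differentiable: autograd routes the gradient back to the selected entries of a, and the entries you did not select receive a zero gradient. So there's no need to worry about it!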
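A quick way to convince yourself is to backpropagate through the slice and inspect the gradient. The sketch below assumes arbitrary sizes B, N, V = 2, 3, 5 and uses a simple sum as a stand-in loss; the gradient should be 1 at column index 2 and 0 everywhere else.

```python
import torch

B, N, V = 2, 3, 5  # arbitrary example sizes (assumption, for illustration only)
a = torch.randn(B, N, V, requires_grad=True)

a_slice = a[:, :, 2, None]  # shape (B, N, 1)
loss = a_slice.sum()        # stand-in loss: d(loss)/d(a_slice) is all ones
loss.backward()

# Expected gradient: ones in the selected column, zeros elsewhere
expected = torch.zeros(B, N, V)
expected[:, :, 2] = 1.0

print(torch.equal(a.grad, expected))  # True
```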
