Suppose I have a tensor:
a = torch.randn(B,N,V)
I want to get the third column of the tensor a along axis V, keeping that axis as a singleton dimension, i.e. in the format (B, N, 1).
I could do this by:
a_slice = a[:, :, 2, None]  # third column is index 2 (zero-based); shape (B, N, 1)
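A quick shape check for the slicing above. It assumes small illustrative sizes (B=2, N=4, V=5, not from the question) and takes the third column as index 2. It also shows two equivalent spellings of the same selection:

```python
import torch

B, N, V = 2, 4, 5  # illustrative sizes, not from the question
a = torch.randn(B, N, V)

# The integer index 2 drops axis V; None then inserts a new singleton dim in its place.
a_slice = a[:, :, 2, None]
print(a_slice.shape)  # torch.Size([2, 4, 1])

# Equivalent spellings of the same selection.
same_via_range = a[:, :, 2:3]                  # a length-1 range keeps axis V
same_via_unsqueeze = a[:, :, 2].unsqueeze(-1)  # drop axis V, then re-add it at the end
print(torch.equal(a_slice, same_via_range), torch.equal(a_slice, same_via_unsqueeze))  # True True
```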
In particular, I worry that this indexing step may not be differentiable. Is that the case?
If so, is there a way to do this with PyTorch functions so that gradients still flow back to a?
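One way to check this empirically is to backpropagate through the selection and inspect the gradient. The sketch below assumes illustrative sizes (B=2, N=4, V=5) and a hypothetical helper column_grad. It compares plain slicing with torch.narrow and torch.index_select. If autograd tracks each version, the gradient of the column sum is 1 in column index 2 and 0 elsewhere for all three:

```python
import torch

B, N, V = 2, 4, 5  # illustrative sizes, not from the question

def column_grad(take_column):
    """Hypothetical helper: return d(sum(take_column(a)))/da for a fresh random a."""
    a = torch.randn(B, N, V, requires_grad=True)
    take_column(a).sum().backward()
    return a.grad

# Three ways to take column index 2 with shape (B, N, 1).
grads = {
    "slicing": column_grad(lambda a: a[:, :, 2, None]),
    "narrow": column_grad(lambda a: torch.narrow(a, 2, 2, 1)),  # dim=2, start=2, length=1
    "index_select": column_grad(lambda a: torch.index_select(a, 2, torch.tensor([2]))),
}

# Expected gradient: 1 in column 2, 0 everywhere else.
expected = torch.zeros(B, N, V)
expected[:, :, 2] = 1.0

for name, g in grads.items():
    print(name, torch.equal(g, expected))  # each line should end in True
```

All three match, which indicates that basic slicing is recorded by autograd like the explicit function calls.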