Suppose I have a tensor:
a = torch.randn(B,N,V)
I want to get the third column of the tensor a along axis V, keeping the result in the shape (B, N, 1).
I could do this by:
a_slice = a[:, :, 2, None]  # index 2 is the third entry (zero-based); None keeps a size-1 last axis
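For reference, here is a small sketch of several equivalent ways to take that slice while keeping the (B, N, 1) shape. The sizes B, N, V = 2, 3, 5 and the variable k are hypothetical, chosen only for illustration, and the sketch assumes "third column" means zero-based index 2 along V.

```python
import torch

# Hypothetical sizes, for illustration only.
B, N, V = 2, 3, 5
a = torch.randn(B, N, V)

k = 2  # zero-based index of the third entry along V

s1 = a[:, :, k, None]                      # integer index, then None re-inserts a size-1 axis
s2 = a[:, :, k:k + 1]                      # a length-1 slice keeps the axis directly
s3 = a[..., k].unsqueeze(-1)               # drop the axis, then add it back
s4 = torch.narrow(a, 2, k, 1)              # narrow(input, dim, start, length)
s5 = a.index_select(2, torch.tensor([k]))  # select along dim 2 with an index tensor

# All five give the same (B, N, 1) tensor.
for s in (s1, s2, s3, s4, s5):
    assert s.shape == (B, N, 1)
    assert torch.equal(s, s1)
```

The slice form a[:, :, k:k + 1] is often the most readable, since it keeps the axis without a separate None.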
In particular, I worry that this indexing step may not be differentiable. Is that the case?
If so, is there a way to do this with PyTorch functions so that gradients still flow back to a?
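One way to check this directly is to let autograd answer it. The sketch below (hypothetical sizes and index k, as before) slices a tensor that requires gradients, backpropagates a simple loss, and compares the gradient against the hand-derived one. It then runs torch.autograd.gradcheck, which compares analytical and numerical gradients and is meant to be used with double-precision inputs.

```python
import torch

B, N, V = 2, 3, 5  # hypothetical sizes
k = 2              # zero-based index of the third entry along V

a = torch.randn(B, N, V, requires_grad=True)
a_slice = a[:, :, k, None]       # shape (B, N, 1)
print(a_slice.grad_fn)           # non-None: the indexing op is recorded by autograd

loss = (a_slice ** 2).sum()
loss.backward()

# d(loss)/d(a) should be 2*a at index k along V and zero everywhere else.
expected = torch.zeros_like(a)
expected[:, :, k] = 2 * a.detach()[:, :, k]
print(torch.allclose(a.grad, expected))

# Numerical vs. analytical gradient check, in double precision.
x = torch.randn(B, N, V, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(lambda t: t[:, :, k, None], (x,)))
```

If both checks pass, gradients reach only the selected entries of a, and the other entries along V receive zero gradient, which is what you would expect from a slice.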