Tensor slicing while preserving dimension

Could anyone explain this? Where can I find an explanation of using lists as indices when slicing? Thanks so much!

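import torch

a = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 8, 10],
    [11, 12, 13, 14, 15]]
    )
for _ in range(4):
    x1 = a[:, 2:3]  # slice: shape (3, 1)
    x2 = a[:, [2]]  # list index: shape (3, 1)
    assert torch.all(x1 == x2)  # same values and same shape

    # address of the storage backing each result
    print(x1.storage().data_ptr(), end='\t')
    print(x2.storage().data_ptr())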
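Since the title is about preserving the dimension, here is a quick shape check on the same tensor (a sketch; the variable names are just for illustration). A plain integer index drops the column dimension, while both the length-1 slice and the one-element list keep it.

```python
import torch

a = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 8, 10],
    [11, 12, 13, 14, 15]])

col_int = a[:, 2]      # integer index: the column dimension is dropped
col_slice = a[:, 2:3]  # slice of length 1: the dimension is kept
col_list = a[:, [2]]   # one-element list: the dimension is kept

print(col_int.shape)    # torch.Size([3])
print(col_slice.shape)  # torch.Size([3, 1])
print(col_list.shape)   # torch.Size([3, 1])
```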

The output is: [screenshot of the printed pointers]

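Indexing with a list, a tensor, or another iterable (apart from a top-level tuple, which is interpreted as one index per dimension) puts you into the realm of “advanced indexing”. A slice such as 2:3 returns a view that shares the storage of a, but advanced indexing always returns a copy, so x2 gets its own storage even though its values and shape match x1.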
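A minimal sketch of the practical difference, reusing the tensor from the question: the slice points into the memory of a, while the list index produces an independent copy. This is why x1 reports the same storage address on every loop iteration, whereas each x2 is a fresh allocation (whose address may or may not repeat, depending on the allocator).

```python
import torch

a = torch.tensor([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 8, 10],
    [11, 12, 13, 14, 15]])

x1 = a[:, 2:3]  # basic slicing: returns a view of a
x2 = a[:, [2]]  # advanced indexing (list index): returns a copy

# The view starts 2 elements into a's memory; the copy has its own allocation.
view_ptr = a.data_ptr() + 2 * a.element_size()
print(x1.data_ptr() == view_ptr)  # True
print(x2.data_ptr() == view_ptr)  # False

# Writing through the view changes a; writing into the copy does not.
x1[0, 0] = 100
x2[1, 0] = -1
print(a[0, 2].item(), a[1, 2].item())  # 100 8
```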
In lieu of PyTorch documentation on this (aside from a C++ cheat sheet), you might consult the NumPy documentation on indexing, since PyTorch largely follows the same rules.
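As a rough cross-check that the NumPy rules carry over, the same experiment in NumPy (a sketch, assuming a reasonably recent NumPy) shows slicing giving a view and a list index giving a copy. It also illustrates the top-level tuple exception: a tuple is split into one index per dimension, while a list is a single advanced index along the first dimension.

```python
import numpy as np

arr = np.array([
    [1, 2, 3, 4, 5],
    [6, 7, 8, 8, 10],
    [11, 12, 13, 14, 15]])

# Slicing gives a view, a list index gives a copy.
print(np.shares_memory(arr, arr[:, 2:3]))  # True
print(np.shares_memory(arr, arr[:, [2]]))  # False

# A top-level tuple means one index per dimension: same as arr[0, 1].
print(arr[(0, 1)])  # 2
# A list selects rows 0 and 1 via advanced indexing.
print(arr[[0, 1]].shape)  # (2, 5)
```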

Best regards

Thomas


thanks! super helpful!