How to slice a 3D tensor?

I have a tensor x = torch.rand((3,4,8)) and I would like to slice x so that the pieces fit into
y = torch.rand((2,3,4,4)).

I am able to slice in 2D using torch.narrow or torch.select, but I am confused about how to do it in 3D.
I need this because I want to split a tensor up into a bunch of patches.

Thank you in advance.


I believe I should be using torch.split():

import torch

x = torch.rand((3, 4, 8))
slices = torch.split(x, 4, dim=2)  # chunks of size 4 along dimension 2

print(len(slices))
> 2

print(slices[0].size())
> torch.Size([3, 4, 4])

This returns a tuple of 3D tensors, each of shape (3, 4, 4).
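To get from those slices to the (2, 3, 4, 4) shape asked about in the question, one option is to stack them along a new leading dimension. This is only a sketch: it assumes the two patches should end up on dimension 0, which the question does not state explicitly.

```python
import torch

x = torch.rand((3, 4, 8))

# Split the last dimension (size 8) into two chunks of width 4.
slices = torch.split(x, 4, dim=2)  # tuple of 2 tensors, each (3, 4, 4)

# Stack the chunks along a new leading dimension (assumed layout).
y = torch.stack(slices, dim=0)
print(y.shape)  # torch.Size([2, 3, 4, 4])
```

Here y[0] holds the first four columns of x and y[1] the last four.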

PS: Just posting this in case it helps someone else.


Note that you don't need narrow or select; basic indexing will do this for you.
For example:

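import torch

a = torch.rand(2, 4, 8)
print(a[:, 1, :3])   # shape (2, 3): row 1 of every matrix, first 3 columns
print(a[1, 2])       # shape (8,): row 2 of matrix 1
print(a[..., 1])     # shape (2, 4): column 1 of every matrix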
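Applied to the tensor from the original question, the same indexing should reproduce what torch.split gives. A minimal sketch; the names left and right are just illustrative:

```python
import torch

x = torch.rand((3, 4, 8))

left = x[:, :, :4]   # first four entries along the last dimension
right = x[..., 4:]   # remaining four; "..." stands for all leading dims

# Compare against the torch.split result from earlier in the thread.
a, b = torch.split(x, 4, dim=2)
print(torch.equal(left, a), torch.equal(right, b))  # True True
```

Indexing is handy for one or two fixed slices, while torch.split saves writing out the ranges by hand when there are many equal-sized chunks.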