If I have a tensor like x = torch.rand((3,4,8)), how can I slice x so that it fits into y = torch.rand((2,3,4,4))?
I am able to slice in 2D using torch.select, but I am confused about the 3D case.
I really need to know this because I want to split up a bunch of patches.
Thank you in advance.
I believe torch.split() is what I should be using:
import torch
x = torch.rand((3,4,8))
# split the last dimension (dim=2) into chunks of size 4
slices = torch.split(x, 4, dim=2)
print(slices[0].shape)  # torch.Size([3, 4, 4])
This returns a tuple of two 3D tensors, each of shape (3, 4, 4).
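To get all the way to the (2, 3, 4, 4) shape of y from the question, the tuple can be stacked along a new leading dimension. This is a minimal sketch, assuming the goal is for y[k] to hold the k-th chunk of size 4 from the last dimension of x; the view/permute line is an alternative route to the same layout, not something from the original answer.

```python
import torch

x = torch.rand((3, 4, 8))

# torch.split returns a tuple of views, each of shape (3, 4, 4)
chunks = torch.split(x, 4, dim=2)

# Stacking the chunks along a new dim 0 gives the (2, 3, 4, 4) layout of y
y = torch.stack(chunks, dim=0)
print(y.shape)  # torch.Size([2, 3, 4, 4])

# Same layout without the intermediate tuple: view the last dim as (2, 4),
# then move the "chunk" dimension to the front
y_alt = x.view(3, 4, 2, 4).permute(2, 0, 1, 3)
print(torch.equal(y, y_alt))  # True
```

torch.stack copies the data into a new contiguous tensor, while the view/permute version avoids the copy but returns a non-contiguous tensor; call .contiguous() on it if a later operation requires that.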
PS: Posting this in case it helps someone else.
Note that you don’t need torch.select here; basic indexing does it for you:
import torch
a = torch.rand(2, 4, 8)
# index 1 on dim 1 (dropping that dim), first 3 entries on the last dim
print(a[:, 1, :3])
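To connect this back to torch.select in 3D, here is a small sketch comparing the indexing expression with explicit calls. It only assumes the standard Tensor.select(dim, index) and Tensor.narrow(dim, start, length) methods; in 3D, select just needs the dimension you are selecting along.

```python
import torch

a = torch.rand(2, 4, 8)

# Basic indexing: an integer index drops dim 1, a slice keeps part of the last dim
b = a[:, 1, :3]
print(b.shape)  # torch.Size([2, 3])

# Explicit equivalent: select removes dim 1 (shape becomes (2, 8)),
# then narrow keeps 3 entries of the new last dim starting at 0
c = a.select(1, 1).narrow(1, 0, 3)
print(torch.equal(b, c))  # True

# Selecting along the last dim of a 3D tensor is the same as a[:, :, 0]
print(torch.equal(a.select(2, 0), a[:, :, 0]))  # True
```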