Does PyTorch have a tensor split method that does interleaved subslicing?

E.g. my input tensor is tensor([0, 1, 2, 3, 4, 5, 6]), to be split into 3 chunks, where the intended outputs are:

tensor([0, 3, 6])
tensor([1, 4])
tensor([2, 5])

The tensor methods chunk, split, and tensor_split all split the tensor into contiguous sub-tensors. I didn’t find one that splits in the interleaved way shown above. Is there one?
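One way to get exactly the interleaved pieces above is plain strided slicing: piece i is t[i::n]. The sketch below wraps that in a small helper; the name interleaved_split and its dim parameter are hypothetical, not a PyTorch API. It relies only on standard basic indexing with a positive step, which returns views that share storage with the input.

```python
import torch

def interleaved_split(t: torch.Tensor, n: int, dim: int = 0):
    """Hypothetical helper: split t into n interleaved pieces along dim.

    Piece i holds indices i, i + n, i + 2n, ... along dim.
    """
    # Build one slice object per dimension; only `dim` gets a stride.
    index = [slice(None)] * t.dim()
    pieces = []
    for i in range(n):
        index[dim] = slice(i, None, n)  # start at i, step by n
        pieces.append(t[tuple(index)])  # basic slicing returns a view
    return pieces

t = torch.tensor([0, 1, 2, 3, 4, 5, 6])
for piece in interleaved_split(t, 3):
    print(piece)
```

Unlike the reshape-based approach, this works for any length, so the pieces may have unequal sizes.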

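I think you can get partially there with a combination of reshape and chunk (you might need a squeeze, depending on your use case). Note that the reshape only works when the number of elements is divisible by the number of chunks, which is why this example uses 9 elements rather than 7: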

>>> t = torch.arange(9)
>>> t = t.reshape(3,3)
>>> t
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
>>> t.chunk(3, dim=1)
(tensor([[0],
        [3],
        [6]]), tensor([[1],
        [4],
        [7]]), tensor([[2],
        [5],
        [8]]))
>>>
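A variant of the reshape approach that avoids the squeeze is to use unbind along the column dimension, which returns 1-D tensors directly. This is a sketch assuming the length is a multiple of n; the second part illustrates that reshape raises a RuntimeError otherwise (for example, the original 7-element tensor with n = 3).

```python
import torch

t = torch.arange(9)
n = 3
# Row r of the (-1, n) view holds t[r*n : (r+1)*n], so column i holds
# t[i], t[i + n], t[i + 2n], ...
# unbind(dim=1) returns those columns as 1-D tensors, so no squeeze is needed.
pieces = t.reshape(-1, n).unbind(dim=1)
print(pieces)

# The reshape only succeeds when t.numel() is a multiple of n.
try:
    torch.arange(7).reshape(-1, n)
except RuntimeError as err:
    print("reshape failed:", err)
```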