torch.split without explicitly specifying split_size_or_sections

If I have the following tensor:

tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])

I know I can split it into 5 chunks using torch.tensor_split(t, 5). However, I'd like to determine the number of chunks automatically from the number of brackets enclosing the data. In this case, one outer bracket encloses 5 inner brackets (rows), so I want to split it into 5 chunks without specifying split_size_or_sections by hand. How can I achieve this in PyTorch?

Thank you in advance

I assume you want to split along dim0 into single slices?
If so,

import torch

x = torch.tensor([[0, 1],
                  [2, 3],
                  [4, 5],
                  [6, 7],
                  [8, 9]])

x.split(split_size=1)

will work, as it returns 5 slices, each of shape [1, 2].
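
As a minimal sketch (assuming only the (5, 2) tensor above), the number of pieces can always be read from x.size(0), so nothing needs to be hardcoded. The snippet compares split(1), tensor_split with that count, and torch.unbind. unbind is a swapped-in alternative not mentioned in the thread: it removes dim 0, so each row comes out as a 1-D tensor instead of a [1, 2] slice.

```python
import torch

x = torch.tensor([[0, 1],
                  [2, 3],
                  [4, 5],
                  [6, 7],
                  [8, 9]])

# The number of chunks is just the size of dim 0, so it never has to be hardcoded.
n = x.size(0)

# split(1) returns chunks of size 1 along dim 0; each keeps the 2-D shape [1, 2].
by_split = x.split(1, dim=0)

# With an int argument, tensor_split treats it as the number of sections.
by_tensor_split = torch.tensor_split(x, n, dim=0)

# unbind removes dim 0 entirely, so each slice is a 1-D tensor of shape [2].
by_unbind = x.unbind(0)

for a, b, c in zip(by_split, by_tensor_split, by_unbind):
    print(a.shape, b.shape, c.shape)
```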

Oh, so I figured it out: I can simply use t.size(dim=1) as split_size_or_sections. But it doesn't split the tensor. Here is my code:

centlen = centroids.size(dim=1)
splitcentroids = torch.tensor_split(centroids, centlen, dim=0)

for x in splitcentroids:
    print('x', x)

and the output is something like this:

x tensor([[[ 1,2],
         [ 3,4],
         [ 5,6]]], device='cuda:0')
x tensor([], device='cuda:0', size=(0, 3, 2))
x tensor([], device='cuda:0', size=(0, 3, 2))
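
A minimal reproduction, assuming centroids has shape (1, 3, 2) (inferred from the printed chunks, not stated in the post) and using a CPU tensor for simplicity. With an int argument, tensor_split interprets it as the number of sections along dim, and dim 0 here has only one entry, so two of the three sections come out empty.

```python
import torch

# Hypothetical stand-in for `centroids`: shape (1, 3, 2), inferred from the
# printed output (the real tensor lives on cuda:0).
centroids = torch.tensor([[[1, 2],
                           [3, 4],
                           [5, 6]]])

centlen = centroids.size(dim=1)  # 3: the size of dim 1, not dim 0
splitcentroids = torch.tensor_split(centroids, centlen, dim=0)

# dim 0 has size 1, so asking for 3 sections along it yields one chunk
# holding everything plus two empty chunks of shape (0, 3, 2).
for x in splitcentroids:
    print(x.shape)
```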

However, what I want is something like:

x tensor([1, 2])
x tensor([3, 4])
x tensor([5, 6])
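
One way to get that output, again assuming centroids has shape (1, 3, 2) (the values below are a hypothetical stand-in matching the output above): split along dim 1, which is the axis that actually holds the three rows, or drop the leading singleton dimension and unbind the rows. Both are sketches, not the only options.

```python
import torch

# Hypothetical stand-in for `centroids`, shape (1, 3, 2).
centroids = torch.tensor([[[1, 2],
                           [3, 4],
                           [5, 6]]])

# Option 1: split along dim 1, the axis with 3 entries.
# Each chunk has shape (1, 1, 2), so flatten it into a 1-D tensor of shape [2].
pieces = [c.flatten() for c in torch.tensor_split(centroids, centroids.size(dim=1), dim=1)]

# Option 2: drop the leading singleton dim, then unbind the rows.
rows = centroids.squeeze(0).unbind(0)

# prints x tensor([1, 2]), then x tensor([3, 4]), then x tensor([5, 6])
for x in rows:
    print('x', x)
```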

Any advice for me? Thank you in advance

Hi @ptrblck, I tried it, but when I ran a loop like this:

for x in splitcentroids:
    print('x', x)

it outputs the following:

tensor([[[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]]])

instead of:

x tensor([1, 2])
x tensor([3, 4])
x tensor([5, 6])

Oh, I realized I was using the wrong dim. The solution worked, thanks for the help!
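
To illustrate the wrong-dim issue, here is a sketch with a hypothetical tensor t of shape (1, 5, 2), matching the extra outer bracket in the printout above: splitting into size-1 chunks along a dimension returns as many chunks as that dimension has entries.

```python
import torch

# Hypothetical stand-in for the printed tensor: the extra outer bracket
# makes its shape (1, 5, 2) rather than (5, 2).
t = torch.arange(10).reshape(1, 5, 2)

# Splitting into size-1 chunks along dim d always gives t.size(d) chunks.
along_dim0 = t.split(1, dim=0)  # dim 0 has size 1 -> a single chunk
along_dim1 = t.split(1, dim=1)  # dim 1 has size 5 -> five chunks of shape (1, 1, 2)

print(len(along_dim0), len(along_dim1))  # 1 5
```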