How to split a tensor in half

IN shape:

[59, 80]

Desired OUT is a list of 2 tensors:

[ [59, 40], [59, 40] ]

How?

Tried:

torch.split(IN, 2, dim=1)

Got:

[ [59, 2, 80], [59, 2, 80] ]

What is wrong?

No, torch.split takes the size of each chunk, not the number of chunks.
Do this:

import torch

a = torch.randn(50, 80)  # tensor of size 50 x 80
b = torch.split(a, 40, dim=1)  # returns a tuple of two 50 x 40 tensors
b = list(b)  # convert to a list if you want
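To make the difference concrete, here is a small sketch using the [59, 80] shape from the question (the variable names are only for illustration). The second argument of torch.split is the width of each piece, so passing 2 along dim=1 of an 80-column tensor yields 40 pieces of shape [59, 2], not two halves.

```python
import torch

a = torch.randn(59, 80)  # same shape as IN in the question

# Second argument = size of each chunk along dim=1, so 40 gives two halves
halves = torch.split(a, 40, dim=1)
print(len(halves), [tuple(t.shape) for t in halves])  # 2 [(59, 40), (59, 40)]

# Passing 2 means "chunks of width 2": 80 / 2 = 40 pieces come back
pieces = torch.split(a, 2, dim=1)
print(len(pieces), tuple(pieces[0].shape))  # 40 (59, 2)
```

If you would rather say "give me 2 pieces" directly, torch.chunk(a, 2, dim=1) takes the number of chunks instead of their size.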

@svd3’s solution is right.
However, I would like to know how you got the strange output of [59, 2, 80].
Could you provide a small code snippet reproducing this output?
I would like to make sure there is no bug in the split function.

@ptrblck I somehow could NOT replicate the bug today, even in my original code. It was definitely very strange, because I had tried split(a, 40, dim=1) as suggested before, along with all the variations, and could not get it to work, so I had to settle for torch.narrow.
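The thread does not show the torch.narrow workaround itself. A minimal sketch of how it could look, assuming the same 2-D [59, 80] input: torch.narrow(input, dim, start, length) returns a view of `length` entries along `dim` starting at `start`, which is equivalent to basic slicing.

```python
import torch

a = torch.randn(59, 80)
half = a.size(1) // 2  # 40 for an 80-column tensor

# torch.narrow(input, dim, start, length): a view, no data is copied
left = torch.narrow(a, 1, 0, half)
right = torch.narrow(a, 1, half, a.size(1) - half)

print(left.shape, right.shape)  # torch.Size([59, 40]) torch.Size([59, 40])
```

The same two halves can be written as a[:, :half] and a[:, half:]; torch.split(a, half, dim=1) produces them in one call.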

But as of today I changed the code to split(a, 40, dim=1) and split(a, 2, dim=1), and both behaved as expected, with no bug. I am not sure what went wrong yesterday, but thanks for the help. I'll keep an eye out in case it happens again.

Edit: I was doing a Twitch stream yesterday, so I could retrieve some information about what transpired, but only in screenshots.

If you want to look at the stream recording, the bug happens at about 2:08:00 to 2:12:00 here:

I've skimmed through the video, and maybe there was some confusion.

Regarding your error at 2:09:
You are printing some info about max_mus but applying the chunk function to inv_mus.

print(max_mus[0])
print(max_mus.size())

return torch.chunk(inv_mus, 2, dim=1)

After calling this function, you print the shapes of both chunks again, and they have a weird shape ([58, 2, 80]).

I didn’t follow your code, but it seems you were comparing max_mus with inv_mus.
Could it be you mixed up these tensors?
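As an illustration of how a mix-up could produce such a shape: the real shape of inv_mus is not visible in the thread, so the 3-D [58, 4, 80] tensor below is purely hypothetical. Chunking dim=1 of a 3-D tensor keeps the other two dimensions, so the result is 3-D as well, while the same call on a 2-D tensor gives the expected halves.

```python
import torch

# Hypothetical 3-D tensor whose dim=1 has size 4 (NOT the real inv_mus)
x = torch.randn(58, 4, 80)
chunks = torch.chunk(x, 2, dim=1)
print([tuple(c.shape) for c in chunks])  # [(58, 2, 80), (58, 2, 80)]

# A 2-D tensor split the same way gives the expected halves
y = torch.randn(58, 80)
y_chunks = torch.chunk(y, 2, dim=1)
print([tuple(c.shape) for c in y_chunks])  # [(58, 40), (58, 40)]
```

Printing the .size() of the exact tensor passed to chunk, right before the call, is a quick way to rule this out.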

The split and chunk functions do what they are supposed to do.

import torch

a = torch.randn(56, 80)
b = torch.split(a, 40, dim=1)

for split in b:
    print(split.size())
> torch.Size([56, 40])
> torch.Size([56, 40])

c = torch.chunk(a, 2, dim=1)
for chunk in c:
    print(chunk.size())
> torch.Size([56, 40])
> torch.Size([56, 40])
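The two functions only agree here because 80 is even. A small sketch with a hypothetical odd width of 81 shows where they diverge: torch.chunk uses a chunk size of ceil(81 / 2) = 41 and returns two pieces, whereas torch.split with a fixed size of 40 leaves a third, 1-column piece. torch.split also accepts a list of sizes (which must sum to the dimension's size) when you want exactly two uneven halves.

```python
import torch

a = torch.randn(56, 81)  # odd width: cannot be halved exactly

# chunk: chunk size is ceil(81 / 2) = 41, so the last chunk is smaller
chunk_shapes = [tuple(c.shape) for c in torch.chunk(a, 2, dim=1)]
print(chunk_shapes)  # [(56, 41), (56, 40)]

# split with a fixed size of 40 produces a leftover 1-column piece
split40_shapes = [tuple(s.shape) for s in torch.split(a, 40, dim=1)]
print(split40_shapes)  # [(56, 40), (56, 40), (56, 1)]

# split with an explicit list of sizes gives exactly two pieces
half = a.size(1) // 2
sizes = [a.size(1) - half, half]  # [41, 40]
list_split_shapes = [tuple(s.shape) for s in torch.split(a, sizes, dim=1)]
print(list_split_shapes)  # [(56, 41), (56, 40)]
```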