# How to split a tensor in half

IN shape:

[59, 80]

Desired OUT is a list of 2 tensors:

[ [59, 40], [59, 40] ]

How?

Tried:

torch.split(IN, 2, dim=1)

Got:

[ [59, 2, 80], [59, 2, 80] ]

What is wrong?

No: `torch.split` takes the *size* of each chunk, not the number of chunks.
Do this:

``````
import torch

a = torch.randn(50, 80)  # tensor of size 50 x 80
b = torch.split(a, 40, dim=1)  # returns a tuple of two [50, 40] tensors
b = list(b)  # convert to a list if you want
``````
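To see why the first attempt did not produce two halves, here is a minimal sketch (reusing the 59 x 80 shape from the question) of what the integer argument to `torch.split` means, along with the list-of-sizes form, which states the width of each section explicitly.

```python
import torch

x = torch.randn(59, 80)

# An int is the size of each chunk: 80 / 2 = 40 pieces, each 2 columns wide
pieces = torch.split(x, 2, dim=1)
print(len(pieces), pieces[0].shape)  # 40 torch.Size([59, 2])

# Halving with an int: each chunk is 40 columns wide, so there are 2 pieces
halves = torch.split(x, 40, dim=1)
print([h.shape for h in halves])  # [torch.Size([59, 40]), torch.Size([59, 40])]

# A list gives the width of each section explicitly (the widths must sum to 80)
left, right = torch.split(x, [40, 40], dim=1)
print(left.shape, right.shape)  # torch.Size([59, 40]) torch.Size([59, 40])
```

So `torch.split(IN, 2, dim=1)` should give 40 tensors of shape `[59, 2]`, not two tensors of shape `[59, 2, 80]`.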

@svd3’s solution is right.
However, I would like to know how you got the strange output of `[59, 2, 80]`.
Could you provide a small code snippet that reproduces this output?
I would like to make sure there is no bug in the `split` function.

@ptrblck I somehow could NOT reproduce the bug today, even in my original code. It was very strange, because I definitely tried `split(a, 40, dim=1)` as suggested before, along with every variation I could think of, and could not get it to work, so I had to settle for `torch.narrow`.

But today I changed the code to `split(a, 40, dim=1)` and `split(a, 2, dim=1)`, and both behaved as expected, with no bug. I am not sure what went wrong yesterday, but thanks for the help. I’ll keep an eye out in case it happens again.
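For reference, the `torch.narrow` workaround mentioned above can also produce the two halves. This is a minimal sketch, not the poster's actual code, assuming the same 59 x 80 input; `torch.narrow(input, dim, start, length)` returns a view of `length` elements along `dim` starting at `start`.

```python
import torch

a = torch.randn(59, 80)
half = a.size(1) // 2  # 40

# Take columns [0, 40) and [40, 80) as views into `a` (no copy is made)
first = torch.narrow(a, 1, 0, half)
second = torch.narrow(a, 1, half, half)
print(first.shape, second.shape)  # torch.Size([59, 40]) torch.Size([59, 40])

# The same values come out of torch.split with a chunk size of 40
s1, s2 = torch.split(a, half, dim=1)
print(torch.equal(first, s1), torch.equal(second, s2))  # True True
```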

Edit: I was doing a Twitch stream yesterday, so I could retrieve some information about what transpired, but only in screenshots.

If you want to look at the stream recording, the bug happens at about 2:08:00 to 2:12:00 here:

I’ve skipped through the video, and it looks like there may have been some confusion.

You are printing some info about `max_mus`, but applying the `chunk` function to `inv_mus`.

``````
print(max_mus[0])
print(max_mus.size())
``````

After calling this function, you print the shapes of both chunks again, and they have a weird shape (`[58, 2, 80]`).

I didn’t follow your code closely, but it seems you were comparing `max_mus` with `inv_mus`.
Could it be that you mixed up these tensors?
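Purely as an illustration (the stream's actual tensors are not shown here), this sketch shows one way a `[58, 2, 80]` shape can appear: printing one 2-D tensor while chunking a different, 3-D tensor. The shape `[58, 4, 80]` used for `inv_mus` is a hypothetical assumption, chosen only so the numbers line up.

```python
import torch

# Hypothetical shapes: max_mus is the 2-D tensor being printed,
# inv_mus is a different 3-D tensor that actually gets chunked
max_mus = torch.randn(58, 80)
inv_mus = torch.randn(58, 4, 80)

print(max_mus.size())  # torch.Size([58, 80])

# chunk splits dim=1 of inv_mus (size 4) into 2 pieces of size 2
for chunk in torch.chunk(inv_mus, 2, dim=1):
    print(chunk.size())  # torch.Size([58, 2, 80]), printed twice
```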

The `split` and `chunk` functions do what they are supposed to do.

``````
import torch

a = torch.randn(56, 80)

# split: the second argument is the size of each chunk
b = torch.split(a, 40, dim=1)
for split in b:
    print(split.size())
# torch.Size([56, 40])
# torch.Size([56, 40])

# chunk: the second argument is the number of chunks
c = torch.chunk(a, 2, dim=1)
for chunk in c:
    print(chunk.size())
# torch.Size([56, 40])
# torch.Size([56, 40])
``````
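One caveat when splitting "in half" is an odd width. A small sketch (81 columns is an arbitrary example) of how the two functions differ then: `torch.chunk` makes each piece `ceil(81 / 2) = 41` wide with a smaller last piece, while `torch.split` with a fixed size of 40 leaves a 1-column remainder. An explicit list of sizes always yields exactly two pieces.

```python
import torch

a = torch.randn(56, 81)  # odd number of columns

# chunk: pieces are ceil(81 / 2) = 41 wide, the last one gets what is left
print([t.size(1) for t in torch.chunk(a, 2, dim=1)])  # [41, 40]

# split with a fixed chunk size of 40 leaves a 1-column remainder
print([t.size(1) for t in torch.split(a, 40, dim=1)])  # [40, 40, 1]

# An explicit list of section sizes always gives exactly two pieces
mid = a.size(1) // 2
left, right = torch.split(a, [mid, a.size(1) - mid], dim=1)
print(left.shape, right.shape)  # torch.Size([56, 40]) torch.Size([56, 41])
```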