IN shape:
[59, 80]
Desired OUT is a list of 2 tensors:
[ [59, 40], [59, 40] ]
How?
Tried:
torch.split(IN, 2, dim=1)
Got:
[ [59, 2, 80], [59, 2, 80] ]
What is wrong?
No, torch.split takes the size of each chunk, not the number of chunks.
Do this:
import torch
a = torch.randn(50, 80)  # tensor of size 50 x 80
b = torch.split(a, 40, dim=1) # it returns a tuple
b = list(b) # convert to list if you want
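A minimal sketch comparing the options on a tensor with the shape from the question, assuming a recent PyTorch version: torch.split takes the width of each piece (or a list of widths), while torch.chunk takes the number of pieces. The name IN comes from the question; the random data is only for illustration.

```python
import torch

IN = torch.randn(59, 80)  # same shape as in the question

# split_size_or_sections=40: each piece is 40 columns wide
by_size = torch.split(IN, 40, dim=1)

# an explicit list of widths also works; the widths must sum to 80
by_sections = torch.split(IN, [40, 40], dim=1)

# chunks=2: ask for two pieces and let PyTorch compute their width
by_count = torch.chunk(IN, 2, dim=1)

for pieces in (by_size, by_sections, by_count):
    print([tuple(p.shape) for p in pieces])  # [(59, 40), (59, 40)]

OUT = list(by_size)  # both functions return a tuple; convert if a list is needed
```

The list form of torch.split is handy when the pieces should have different widths.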
@svd3’s solution is right.
However, I would like to know how you got the strange output of [59, 2, 80].
Could you provide a small code snippet reproducing this output?
I would like to make sure there is no bug in the split function.
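For reference, a quick check of what the call from the question produces on a 2-D tensor, assuming IN really has shape [59, 80]: a split size of 2 along dim=1 cuts the 80 columns into 40 pieces of shape [59, 2]. torch.split never adds a dimension, so a 2-D input should not yield [59, 2, 80] pieces.

```python
import torch

IN = torch.randn(59, 80)
pieces = torch.split(IN, 2, dim=1)  # split size 2, not "2 pieces"

print(len(pieces))       # 40
print(pieces[0].shape)   # torch.Size([59, 2])
# every piece keeps the input's number of dimensions
print(all(p.dim() == IN.dim() for p in pieces))  # True
```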
@ptrblck I somehow could NOT replicate the bug today, even in my original code. It was very strange, because I definitely tried split(a, 40, dim=1) as suggested before, tried all the variations, could not get it to work, and had to settle for torch.narrow.
But today I changed the code to split(a, 40, dim=1) and split(a, 2, dim=1), and both behaved as expected, with no bug. I am not sure what went wrong yesterday, but thanks for the help. I’ll keep an eye out in case it happens again.
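The torch.narrow workaround mentioned here can be sketched as follows, assuming the same [59, 80] input: torch.narrow(input, dim, start, length) returns a view of `length` columns beginning at `start`, so two calls reproduce what torch.split(a, 40, dim=1) returns.

```python
import torch

a = torch.randn(59, 80)

# torch.narrow(input, dim, start, length) returns a view, not a copy
left = torch.narrow(a, 1, 0, 40)    # columns 0..39
right = torch.narrow(a, 1, 40, 40)  # columns 40..79

# compare with the split-based version
ref = torch.split(a, 40, dim=1)
print(torch.equal(left, ref[0]), torch.equal(right, ref[1]))  # True True
```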
Edit: I was doing a Twitch stream yesterday, so I could retrieve some information about what transpired, but only as screenshots.
If you want to look at the stream recording, the bug happens at about 2:08:00 to 2:12:00 here:
I’ve skimmed through the video, and maybe there was some confusion.
Regarding your error at 2:09: you are printing some info about max_mus, but applying the chunk function to inv_mus.
print(max_mus[0])
print(max_mus.size())
return torch.chunk(inv_mus, 2, dim=1)
After calling this function, you print the shapes of both chunks again, and they have a weird shape ([58, 2, 80]).
I didn’t follow your code, but it seems you were comparing max_mus with inv_mus. Could it be that you mixed up these tensors?
The split and chunk functions do what they are supposed to do:
import torch
a = torch.randn(56, 80)
b = torch.split(a, 40, dim=1)
for split in b:
print(split.size())
> torch.Size([56, 40])
> torch.Size([56, 40])
c = torch.chunk(a, 2, dim=1)
for chunk in c:
print(chunk.size())
> torch.Size([56, 40])
> torch.Size([56, 40])
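When the dimension isn't evenly divisible, the two functions size the pieces differently. A small sketch, assuming a width of 80 split with a piece size of 30 versus 3 chunks: torch.split returns pieces of the requested width plus a smaller remainder, while torch.chunk uses ceil(80 / 3) = 27 columns per chunk, leaving 26 for the last one. torch.chunk can also return fewer pieces than requested.

```python
import torch

a = torch.randn(56, 80)

split_widths = [s.size(1) for s in torch.split(a, 30, dim=1)]
chunk_widths = [c.size(1) for c in torch.chunk(a, 3, dim=1)]

print(split_widths)  # [30, 30, 20]
print(chunk_widths)  # [27, 27, 26]

# width 6 with chunks=4: each chunk is ceil(6 / 4) = 2 wide, so only 3 come back
few = torch.chunk(torch.randn(2, 6), 4, dim=1)
print(len(few))  # 3
```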