Torch.cat doesn't notice dim=1

I am confused by the error I am getting; it seems that torch.cat is ignoring the dim argument.

Output:
torch.Size([100, 10]) torch.Size([100, 32])

Traceback:

     10     print(y_sample.shape, z_sample.shape)
---> 11     x_dist = ae.yz_dec(torch.cat([y_sample, z_sample]), dim=1)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 10 and 32 in dimension 1 at ../aten/src/TH/generic/THTensor.cpp:612

For comparison, this piece of code runs without any errors and outputs torch.Size([100, 42]).

a, b = torch.ones(100, 10), torch.ones(100, 32)
c = torch.cat([a, b], dim=1)
print(a.shape, b.shape, c.shape)

As opposed to this one,

a, b = torch.ones(100, 10), torch.ones(100, 32)
c = torch.cat([a, b], dim=0)
print(a.shape, b.shape, c.shape)

which fails with the same runtime error as my original code above:

RuntimeError                              Traceback (most recent call last)
<ipython-input-12-d7fbede0f75e> in <module>
      1 a, b = torch.ones(100, 10), torch.ones(100, 32)
----> 2 c = torch.cat([a, b], dim=0)
      3 
      4 print(a.shape, b.shape, c.shape)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 10 and 32 in dimension 1 at ../aten/src/TH/generic/THTensor.cpp:612

I am on Python 3.7.5 (default, Nov 1 2019, 02:16:32) [Clang 11.0.0 (clang-1100.0.33.8)] and PyTorch 1.4.0.

I think this is expected. When you concatenate on dimension 0, all the other dimensions must have the same size.
When you cat on dim=1 it works because the other dimension is 100 for both inputs.
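For example (the shapes below are made up just for illustration), cat on dim=0 is fine as soon as the non-concatenated dimension matches:

import torch

# dim 1 is 10 for both tensors, so stacking along dim 0 is allowed
a, b = torch.ones(100, 10), torch.ones(50, 10)
c = torch.cat([a, b], dim=0)
print(c.shape)  # torch.Size([150, 10])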

@albanD Indeed! But in the broken code above, the first snippet, I also concatenate on dim=1.

ETA: to make it more visible

I do torch.cat([y_sample, z_sample]), dim=1, and just before that print(y_sample.shape, z_sample.shape) outputs torch.Size([100, 10]) torch.Size([100, 32]).

Oh… it is a typo on my part: dim=1 is passed to ae.yz_dec instead of torch.cat, which therefore falls back to its default dim=0.
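
For anyone who lands here later, the fix is just moving the closing parenthesis so that dim=1 goes to torch.cat rather than to the decoder (ae.yz_dec, y_sample, and z_sample are from my original snippet):

# before: torch.cat used its default dim=0, and dim=1 was passed to ae.yz_dec
# x_dist = ae.yz_dec(torch.cat([y_sample, z_sample]), dim=1)

# after: [100, 10] and [100, 32] are concatenated along dim 1 into [100, 42]
x_dist = ae.yz_dec(torch.cat([y_sample, z_sample], dim=1))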
