I am confused by the error I am receiving; it seems that torch is ignoring the dim argument.
Output:
torch.Size([100, 10]) torch.Size([100, 32])
Traceback:
10 print(y_sample.shape, z_sample.shape)
---> 11 x_dist = ae.yz_dec(torch.cat([y_sample, z_sample]), dim=1)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 10 and 32 in dimension 1 at ../aten/src/TH/generic/THTensor.cpp:612
For comparison, this piece of code runs without any errors and outputs torch.Size([100, 42]).
a, b = torch.ones(100, 10), torch.ones(100, 32)
c = torch.cat([a, b], dim=1)
print(a.shape, b.shape, c.shape)
Whereas this one,
a, b = torch.ones(100, 10), torch.ones(100, 32)
c = torch.cat([a, b], dim=0)
print(a.shape, b.shape, c.shape)
encounters the same runtime error as my code above.
RuntimeError Traceback (most recent call last)
<ipython-input-12-d7fbede0f75e> in <module>
1 a, b = torch.ones(100, 10), torch.ones(100, 32)
----> 2 c = torch.cat([a, b], dim=0)
3
4 print(a.shape, b.shape, c.shape)
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 10 and 32 in dimension 1 at ../aten/src/TH/generic/THTensor.cpp:612
I am on Python 3.7.5 (default, Nov 1 2019, 02:16:32) [Clang 11.0.0 (clang-1100.0.33.8)] and PyTorch 1.4.0.
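Looking at my traceback again, I notice the dim=1 argument sits outside the torch.cat(...) parentheses, i.e. it is being passed to ae.yz_dec rather than to torch.cat, so cat presumably falls back to its default of dim=0. A minimal sketch reproducing that reading (assuming torch.cat defaults to dim=0 when dim is omitted):

```python
import torch

a, b = torch.ones(100, 10), torch.ones(100, 32)

# dim omitted -> torch.cat defaults to dim=0, which requires
# all other dimensions to match; 10 != 32 in dimension 1.
try:
    torch.cat([a, b])
except RuntimeError as e:
    print("cat without dim fails:", e)

# dim=1 passed to torch.cat itself works as expected.
c = torch.cat([a, b], dim=1)
print(c.shape)  # torch.Size([100, 42])
```

If that is right, moving the closing parenthesis so the call reads torch.cat([y_sample, z_sample], dim=1) should fix it, but I would like confirmation that this is what is happening.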