import torch

# Demonstrate concatenating a 2-D tensor onto a 3-D tensor with torch.cat.
# torch.cat needs inputs with the same number of dimensions, and all sizes must
# match except along the concatenation axis. So `b` must first gain a leading axis.
a = torch.rand(1, 3, 4)
print(a.shape)  # torch.Size([1, 3, 4])
b = torch.rand(3, 4)
print(b.shape)  # torch.Size([3, 4])
# unsqueeze(0) inserts a size-1 axis at position 0: (3, 4) -> (1, 3, 4).
b = b.unsqueeze(0)
print(b.shape)  # torch.Size([1, 3, 4])
# Join along dim 0. The trailing sizes (3, 4) agree, so the result is (2, 3, 4).
c = torch.cat([a, b], dim=0)
print(c.shape)  # torch.Size([2, 3, 4])
# 3 Likes
import torch

# Concatenate a (1, 3, 4) tensor with a (3, 4) tensor along dim 0.
# torch.cat requires equal rank and matching sizes on every non-concatenated
# axis, so the 2-D tensor is lifted to 3-D before joining.
a = torch.rand(1, 3, 4)
print(a.shape)  # torch.Size([1, 3, 4])
b = torch.rand(3, 4)
print(b.shape)  # torch.Size([3, 4])
# Add a size-1 leading axis: (3, 4) -> (1, 3, 4).
b = b.unsqueeze(0)
print(b.shape)  # torch.Size([1, 3, 4])
# Stack the two size-1 slices along dim 0, giving (2, 3, 4).
c = torch.cat([a, b], dim=0)
print(c.shape)  # torch.Size([2, 3, 4])