Maybe it is a little late, but it is possible to achieve this: flatten each tensor to shape `[batch, -1]` with `view`, then concatenate along dimension 1. For example:
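```python
import torch

x = torch.rand(2, 400, 1, 1)
y = torch.rand(2, 400, 2, 2)
z = torch.rand(2, 400, 4, 4)
# Keep the batch dimension and flatten everything else: [2, C*H*W]
x = x.view(x.size(0), -1)  # [2, 400]
y = y.view(y.size(0), -1)  # [2, 1600]
z = z.view(z.size(0), -1)  # [2, 6400]
# Join the flattened features side by side
out = torch.cat([x, y, z], dim=1)
```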
`out`'s size will be `[2, 8400]` (400 + 1600 + 6400 columns). Hope it can help you.
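If you have many feature maps, or some of them are non-contiguous (for example after `permute`), a slightly more general sketch is below. `flatten_and_concat` is a hypothetical helper name, not a PyTorch API; it uses `torch.flatten(t, start_dim=1)`, which behaves like `reshape` and so also accepts non-contiguous tensors, whereas `view` raises an error on them.

```python
import torch

def flatten_and_concat(tensors):
    """Hypothetical helper: flatten each (N, C, H, W) tensor to (N, C*H*W)
    and join them along dim 1.

    torch.flatten(start_dim=1) keeps the batch dimension and, unlike view(),
    also works on non-contiguous tensors (it copies only when it must).
    """
    batch = tensors[0].size(0)
    # torch.cat along dim=1 requires every input to share the batch size.
    if any(t.size(0) != batch for t in tensors):
        raise ValueError("all tensors must have the same batch size")
    return torch.cat([torch.flatten(t, start_dim=1) for t in tensors], dim=1)

x = torch.rand(2, 400, 1, 1)
y = torch.rand(2, 400, 2, 2)
z = torch.rand(2, 400, 4, 4)
out = flatten_and_concat([x, y, z])
print(out.shape)  # torch.Size([2, 8400])

# A permuted tensor is non-contiguous: w.view(2, -1) would raise here,
# but flatten() copies the data as needed.
w = torch.rand(2, 4, 4, 400).permute(0, 3, 1, 2)  # shape (2, 400, 4, 4)
print(flatten_and_concat([x, w]).shape)  # torch.Size([2, 6800])
```

The first 400 columns of `out` are `x`'s features, the next 1600 are `y`'s, and the last 6400 are `z`'s, in the order the tensors are passed in.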