Concatenate a column to a tensor with different dimensions

Hi,

I have a tensor like this:

tensor([[[[0., 0., 0.],
          [0., 0., 0.]]],

        [[[0., 0., 0.],
          [0., 0., 0.]]]])

I want to concatenate a column of ones to the tensor, without any loop, like this:

tensor([[[[1., 0., 0., 0.],
          [1., 0., 0., 0.]]],

        [[[1., 0., 0., 0.],
          [1., 0., 0., 0.]]]])

I tried this code, but I get an error:

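import torch
temp = torch.ones((2, 1, 2, 1))    # column of ones, shape (2, 1, 2, 1)
input = torch.zeros((2, 1, 2, 3))  # original tensor, shape (2, 1, 2, 3)
print(torch.cat((temp, input), 0))  # concatenating along dim 0: this line raises the error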

Do you have any idea?

Many thanks

I think print(torch.cat((temp, input), 3)) gives the desired result.

Thanks, you're right. What does the 3 mean?

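3 is the dimension along which the tensors are concatenated. torch.cat requires all other dimensions to match: temp has shape (2, 1, 2, 1) and input has shape (2, 1, 2, 3), so they agree everywhere except dimension 3, and the result has shape (2, 1, 2, 4). Concatenating along dimension 0 fails because dimension 3 would then have to match as well.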
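A minimal sketch of that rule, using the shapes from the first post (the tensor is renamed `x` here only to avoid shadowing Python's built-in `input`): concatenating along dim 3 works because dims 0 to 2 agree, while dim 0 fails because dim 3 differs.

```python
import torch

temp = torch.ones((2, 1, 2, 1))   # the column of ones
x = torch.zeros((2, 1, 2, 3))     # the original tensor

# Dims 0, 1 and 2 match (2, 1, 2), so the tensors can be joined along dim 3.
out = torch.cat((temp, x), 3)
print(out.shape)  # torch.Size([2, 1, 2, 4])

# Every row now starts with 1, followed by the original three zeros.
assert torch.equal(out[..., 0], torch.ones(2, 1, 2))
assert torch.equal(out[..., 1:], x)

# Along dim 0, dim 3 would also have to match (1 vs. 3), so this raises.
dim0_failed = False
try:
    torch.cat((temp, x), 0)
except RuntimeError:
    dim0_failed = True
print(dim0_failed)  # True
```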

Thanks a lot. Based on what you said, I wrote the code below, but I don't understand what the numbers -1 and -2 mean. Do you have any idea?

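import torch
input = torch.zeros((2, 1, 4, 2))
print(input)

# Column of ones matching the first three dims of input: (2, 1, 4, 1)
temp = torch.ones((input.shape[0], input.shape[1], input.shape[2], 1))
# equivalent here: temp = torch.ones((2, 1, 4, 1))

# Add the column on the left
input = torch.cat((temp, input), 3)
print(input)

# Add the column on the right
input = torch.cat((input, temp), -1)
print(input)

# Row of ones as wide as the new last dim: (2, 1, 1, 4)
temp = torch.ones((input.shape[0], input.shape[1], 1, input.shape[3]))

# Add the row at the bottom
input = torch.cat((input, temp), -2)
print(input)

# Add the row at the top
input = torch.cat((temp, input), -2)
print(input)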

The result is this:

tensor([[[[1., 1., 1., 1.],
          [1., 0., 0., 1.],
          [1., 0., 0., 1.],
          [1., 0., 0., 1.],
          [1., 0., 0., 1.],
          [1., 1., 1., 1.]]],


        [[[1., 1., 1., 1.],
          [1., 0., 0., 1.],
          [1., 0., 0., 1.],
          [1., 0., 0., 1.],
          [1., 0., 0., 1.],
          [1., 1., 1., 1.]]]])
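For this particular pattern (a border of ones around the last two dimensions), `torch.nn.functional.pad` with `value=1` is an alternative to the four `torch.cat` calls. This is a swapped-in technique, not what the post uses; the sketch below, assuming the shape from the post, checks that both approaches produce the same tensor.

```python
import torch
import torch.nn.functional as F

x = torch.zeros((2, 1, 4, 2))

# Version from the post: four concatenations with blocks of ones.
col = torch.ones((x.shape[0], x.shape[1], x.shape[2], 1))
y = torch.cat((col, x), 3)          # left column
y = torch.cat((y, col), 3)          # right column
row = torch.ones((y.shape[0], y.shape[1], 1, y.shape[3]))
y = torch.cat((y, row), 2)          # bottom row
y = torch.cat((row, y), 2)          # top row

# F.pad reads the pad tuple starting from the last dimension:
# (left, right) for dim 3, then (top, bottom) for dim 2.
z = F.pad(x, (1, 1, 1, 1), mode="constant", value=1.0)

print(z.shape)            # torch.Size([2, 1, 6, 4])
print(torch.equal(y, z))  # True
```

Padding in a single call avoids building the intermediate column and row tensors, but `torch.cat` remains the more general tool when the added block is not a constant.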

Negative dimension indices count from the end: torch.cat takes them modulo input.ndimension(). In your first example, input.ndimension() == 4, so -1 = 3 mod 4, i.e. dimension 3. In your second example, -2 = 2 mod 4 for the same reason, i.e. dimension 2.
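A small check of the wrap-around rule for a 4-D tensor, assuming the shapes from the post: each negative `dim` names the same dimension as `dim % x.dim()`, so `torch.cat` with `-1` and with `3` (or `-2` and `2`) gives identical results. The loop is only illustrative.

```python
import torch

x = torch.zeros((2, 1, 4, 2))
ndim = x.dim()  # 4, same as x.ndimension()

# Python's % maps each negative index into range(ndim).
for neg in range(-ndim, 0):
    print(neg, "->", neg % ndim)

col = torch.ones((2, 1, 4, 1))
row = torch.ones((2, 1, 1, 2))

# dim=-1 and dim=3 both select the last dimension.
assert torch.equal(torch.cat((x, col), -1), torch.cat((x, col), 3))
# dim=-2 and dim=2 both select the second-to-last dimension.
assert torch.equal(torch.cat((x, row), -2), torch.cat((x, row), 2))

print(x.size(-1), x.size(3))  # 2 2
```

Negative indices are convenient when the code should keep working if leading (batch) dimensions are added or removed, since "last" and "second-to-last" stay the same.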

Many thanks. So, for anyone who has the same problem, the code can be rewritten with positive dimensions as below:

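import torch
input = torch.zeros((2, 1, 4, 2))
print(input)

# Column of ones matching the first three dims of input: (2, 1, 4, 1)
temp = torch.ones((input.shape[0], input.shape[1], input.shape[2], 1))
# equivalent here: temp = torch.ones((2, 1, 4, 1))

# Add the column on the left, then on the right (dim 3 = last dim)
input = torch.cat((temp, input), 3)
print(input)

input = torch.cat((input, temp), 3)
print(input)

# Row of ones as wide as the new last dim: (2, 1, 1, 4)
temp = torch.ones((input.shape[0], input.shape[1], 1, input.shape[3]))

# Add the row at the bottom, then at the top (dim 2 = second-to-last dim)
input = torch.cat((input, temp), 2)
print(input)

input = torch.cat((temp, input), 2)
print(input)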
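The same steps can be wrapped in a small reusable function. `add_border_of_ones` is a hypothetical helper name, not a PyTorch API; this sketch builds the column and row with `new_ones`, so they share the input's dtype and device, and it accepts any tensor with at least two dimensions.

```python
import torch

def add_border_of_ones(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: surround the last two dims of x with ones."""
    # Column of ones with the same leading dims as x and width 1.
    col = x.new_ones((*x.shape[:-1], 1))
    x = torch.cat((col, x, col), dim=-1)      # left and right columns
    # Row of ones spanning the new (wider) last dimension.
    row = x.new_ones((*x.shape[:-2], 1, x.shape[-1]))
    return torch.cat((row, x, row), dim=-2)   # top and bottom rows

out = add_border_of_ones(torch.zeros((2, 1, 4, 2)))
print(out.shape)  # torch.Size([2, 1, 6, 4])
print(out[0, 0])
```

Passing three tensors to a single `torch.cat` call adds both sides at once, and using `dim=-1`/`dim=-2` keeps the helper independent of how many leading dimensions the input has.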

Thanks again