Concatenate layer output with additional input data

Your approach is completely fine, and you can use view instead of unsqueezing via None indexing.

@ptrblck Now I am getting this RuntimeError: Given transposed=1, weight of size [128, 2048, 4, 4], expected input[64, 134, 1, 1] to have 128 channels, but got 134 channels instead

This error is raised since you’ve increased the number of channels by concatenating the latent tensor with the one-hot encoded tensor.
I’m not familiar with the use case and cannot comment on whether this is the right way to feed the label tensor to the model, but the workaround would be to increase the in_channels of the conv layer to 134.
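A minimal sketch of this workaround (ny = 6 is an assumption based on the one-hot size in your error message; the names are illustrative):

import torch
import torch.nn as nn

nz, ny, ngf, b_size = 128, 6, 128, 64   # latent size, assumed number of classes, feature maps, batch size

# the first generator layer now expects the concatenated input: nz + ny = 134 channels
conv1 = nn.ConvTranspose2d(nz + ny, ngf * 16, 4, 1, 0, bias=False)

latent = torch.randn(b_size, nz, 1, 1)              # [64, 128, 1, 1]
one_hot = torch.zeros(b_size, ny, 1, 1)             # [64, 6, 1, 1] dummy one-hot label tensor
out = conv1(torch.cat((latent, one_hot), dim=1))    # input [64, 134, 1, 1] -> output [64, 2048, 4, 4]
print(out.shape)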

@ptrblck I already tried that and got a similar error: RuntimeError: Given transposed=1, weight of size [134, 2048, 4, 4], expected input[64, 140, 1, 1] to have 134 channels, but got 140 channels instead

This paper (https://arxiv.org/pdf/1611.06355.pdf) mentions the concatenation but does not mention a one-hot vector, while this paper (https://arxiv.org/pdf/1702.01983.pdf) does mention the one-hot vector.

Could you help me figure out where I am going wrong?

Based on the new error message, it seems you’ve concatenated the one-hot encoded tensor twice to the latent tensor (128 + 6 + 6 = 140).

@ptrblck
This is the specification:
nz = 134, ngf = 128, b_size = 64
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 16, 4, 1, 0, bias=False),
# Generate batch of latent vectors

noise = torch.randn(b_size, nz, 1, 1, device=device)

So I realized that while generating the noise I should use 128 channels instead of nz.
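For reference, a minimal sketch of what I mean (the labels and one-hot tensor are just dummies to check the shapes):

import torch
import torch.nn.functional as F

b_size, nz, ny = 64, 134, 6                                      # nz already includes the ny label channels
device = torch.device('cpu')

noise = torch.randn(b_size, nz - ny, 1, 1, device=device)        # [64, 128, 1, 1]
labels = torch.randint(0, ny, (b_size,), device=device)          # dummy class labels
one_hot = F.one_hot(labels, ny).float().view(b_size, ny, 1, 1)   # [64, 6, 1, 1]

gen_input = torch.cat((noise, one_hot), dim=1)                   # [64, 134, 1, 1] -> matches nz
print(gen_input.shape)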
Thanks for your help. I might knock on your door when I move to the discriminator :sweat_smile:

@ptrblck

Because of your assistance, I managed the generator-side concatenation. Now for the discriminator: as shown in the figure, the concatenation of the tensors happens after the first convolution layer.
I used the same logic as in the first post of this thread.

Your forward method

def forward(self, image, data):
    x1 = self.cnn(image)
    x2 = data
    x = torch.cat((x1, x2), dim=1)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return x

My forward method

def forward(self, input, input_age):
    x = self.main_0(input)
    print(x.size())
    print(input_age.size())
    y = input_age
    input2 = torch.cat((x, y), dim=1)
    return self.main(input2)

Now during my training, I am getting the following error

 Starting Training Loop...
torch.Size([64, 16, 64, 64])
torch.Size([64, 6, 1, 1])

RuntimeError: Sizes of tensors must match except in dimension 2. Got 1 and 64

Why are these not concatenated even though both tensors have the same number of dimensions?

Edit: I found this in the original .lua code:

 -- --==DISCRIMINATOR==--

local netD = nn.Sequential()

-- Need a parallel table to put different layers for X (conv layers) 
-- and Y (none) before joining both inputs together.
local pt = nn.ParallelTable()

-- Convolutions applied only on X input
local Xconv = nn.Sequential() 
-- input is nc x opt.fineSize x opt.fineSize
Xconv:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1))
Xconv:add(nn.LeakyReLU(0.2, true))

fltMult = 1

-- Replicate Y to match convolutional filter dimensions
local Yrepl = nn.Sequential()

-- ny -> ny x opt.fineSize/2 (replicate 2nd dimension)
Yrepl:add(nn.Replicate(opt.fineSize/2,2,1))
-- ny x 8 -> ny x opt.fineSize/2 x opt.fineSize/2 (replicate 3rd dimension)
Yrepl:add(nn.Replicate(opt.fineSize/2,3,2))


-- Join X and Y
pt:add(Xconv)
pt:add(Yrepl)
netD:add(pt)
netD:add(nn.JoinTable(1,3)) 

How can I implement the same logic in PyTorch?

To concatenate two tensors in a specific dimension dim, all other dimensions must have the same shape. I’m not completely sure what the Torch7 code does, but I assume it’s repeating the tensor in the spatial dimensions, which could be done in PyTorch using:

y = y.expand(-1, -1, 64, 64)

This will create a tensor in the shape [64, 6, 64, 64], which can be concatenated with x in dim1.
Note that expand doesn’t trigger a copy of the tensor and just manipulates the metadata. To trigger a copy you could use repeat instead or call .contiguous() on the tensor afterwards.
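Putting it together, a minimal sketch of the concatenation (the shapes are taken from your print statements; the next conv layer would then need 16 + 6 = 22 input channels):

import torch

x = torch.randn(64, 16, 64, 64)      # output of the first conv block
y = torch.randn(64, 6, 1, 1)         # label tensor

y = y.expand(-1, -1, 64, 64)         # [64, 6, 64, 64]; no copy, only the metadata changes
out = torch.cat((x, y), dim=1)       # [64, 22, 64, 64]
print(out.shape)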

@ptrblck Thanks for your help. It works. But I did not understand: "Note that expand doesn’t trigger a copy of the tensor and just manipulates the metadata. To trigger a copy you could use repeat instead or call .contiguous() on the tensor afterwards."

Could you provide a simple example?

Sure, here is an example:

x = torch.randn(2, 1, 1)
print(x.size(), x.stride())

# expand manipulates the metadata of the tensor (stride in this case)
x_expand = x.expand(-1, 3, 3)
print(x_expand.size(), x_expand.stride())
# .contiguous() triggers a memory copy
x_expand_cont = x_expand.contiguous()
print(x_expand_cont.size(), x_expand_cont.stride())

# .repeat() triggers a memory copy
x_repeat = x.repeat(1, 3, 3)
print(x_repeat.size(), x_repeat.stride())

As you can see, expand only changes the metadata and no copy is made.
If you are running into an error where contiguous memory is needed, you would have to call contiguous() on the expanded tensor or use repeat instead.
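For example, a view call fails on the expanded tensor, since it needs compatible strides, while it works after the copy:

import torch

x = torch.randn(2, 1, 1)
x_expand = x.expand(-1, 3, 3)

# view needs compatible strides and raises an error on the expanded tensor
try:
    x_expand.view(-1)
except RuntimeError as e:
    print(e)

# after contiguous() (or by using repeat) a real copy exists, so view works
print(x_expand.contiguous().view(-1).shape)  # torch.Size([18])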

Hi @ptrblck,
Thank you for your awesome contribution!

While concatenating a layer output with additional data, does it matter if the concatenation is done after passing through the ReLU activation, or should we concatenate first and then pass the result through ReLU?

Case 1:

class DeepQNetwork(nn.Module):
    def __init__(self, lr, input_dims, fc1_dims, fc2_dims, 
            n_actions):
        super(DeepQNetwork, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = 512
        self.fc2_dims = 512
        self.fc3_dims = 512
        self.fc4_dims = 512

        
        self.n_actions = n_actions
        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self.fc3 = nn.Linear(512+6,self.fc3_dims)
        self.fc4 = nn.Linear(self.fc3_dims, self.fc4_dims)
        self.fc5 = nn.Linear(self.fc4_dims, self.n_actions)
        
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state,velocity): 
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = T.cat([x,velocity],dim = 1)
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        actions = self.fc5(x)

        return actions

Or Case 2:

class DeepQNetwork(nn.Module):
    def __init__(self, lr, input_dims, fc1_dims, fc2_dims, 
            n_actions):
        super(DeepQNetwork, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.fc3_dims = 512
        self.fc4_dims = 512

        
        self.n_actions = n_actions
        self.fc1 = nn.Linear(*self.input_dims, self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self.fc3 = nn.Linear(518,self.fc3_dims)
        self.fc4 = nn.Linear(self.fc3_dims, self.fc4_dims)
        self.fc5 = nn.Linear(self.fc4_dims, self.n_actions)
        
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, state,velocity): 
        x = F.relu(self.fc1(state))
        x = self.fc2(x)
        x = T.cat([x,velocity],dim = 1)
        x = F.relu(x)
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        actions = self.fc5(x)

        return actions

I don’t know what the best approach would be and guess “it depends” on the use case, as so often.
Anyway, let us know if you’ve experimented with both approaches and have some recommendations. :slight_smile:

Hi @ptrblck, thank you very much for your helpful ideas. I am quite new to this topic and am currently trying to concatenate two tensors x1 and x2, where x1 is the transformed image tensor returned by my feature extractor and x2 represents additional data which I would like to add to it. Tensor x1 has the size [8, 1024, 13, 13] and tensor x2 the size [8, 3, 416, 416].

As expected, I receive the following error message.

RuntimeError: Sizes of tensors must match except in dimension 1. 
Got 13 and 416 in dimension 2 (The offending index is 1)

How could I transform the tensors so that I am able to concatenate both?

Many thanks!

Hi Janina,

it depends a bit on which dimension you would like to concatenate the tensors.
Currently the batch size is equal, but all other dimensions differ in size. When concatenating tensors, all dimensions except the one used for the concatenation must be equal.
Assuming you would like to use torch.cat in dim1, you would have to make sure the spatial sizes are equal, e.g. by using a pooling operation, a conv layer, etc.
Assuming you reduce the spatial size of the second tensor to 13x13, the result would then be [8, 1027, 13, 13].
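A minimal sketch using adaptive pooling (just one possible way to make the spatial sizes match):

import torch
import torch.nn.functional as F

x1 = torch.randn(8, 1024, 13, 13)          # feature extractor output
x2 = torch.randn(8, 3, 416, 416)           # additional input

x2 = F.adaptive_avg_pool2d(x2, (13, 13))   # [8, 3, 13, 13]
out = torch.cat((x1, x2), dim=1)           # [8, 1027, 13, 13]
print(out.shape)

Alternatively, a strided conv layer could reduce the spatial size if you want this reduction to be learnable.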
Would this be your expected use case and if not, could you explain it a bit more?

thanks a lot! This was a good hint.

Thank you for providing this code @ptrblck! This is really helpful. I have a question that is theoretical and not a problem in coding (that I know of), so please ignore if it is not appropriate here. I have set up a similar network with ResNet:

class ResNetFeature(nn.Module):
    def __init__(self, num_classes, num_features):
        super(ResNetFeature, self).__init__()
        self.cnn = models.resnet18(pretrained=False) # I've also tried using pretrained=True
        self.cnn.fc = nn.Linear(
            self.cnn.fc.in_features, 20)
        self.fc1 = nn.Linear(20 + num_features, 60)
        self.fc2 = nn.Linear(60, 100)
        self.fc3 = nn.Linear(100, 30)
        self.fc4 = nn.Linear(30, num_classes)
    #
    def forward(self, image, data):
        x1 = self.cnn(image)
        x2 = data
        x = torch.cat((x1, x2), dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

If I train using this module, I get lower accuracy than if I train with just ResNet-18 on the images (and don’t include my numeric features). Any ideas why? If the numeric features I’m using weren’t helpful in classification, shouldn’t the network just ignore them?

In theory your model might ignore them, but it would also have to “learn” to ignore them, which could at least delay the training.
Here is a small code snippet which shows how additional data with values in another range can block the training (it might still converge with more iterations):

import torch
import torch.nn as nn
import torchvision.models as models

class ResNetFeature(nn.Module):
    def __init__(self, num_classes, num_features):
        super(ResNetFeature, self).__init__()
        self.cnn = models.resnet18(pretrained=False) # I've also tried using pretrained=True
        self.cnn.fc = nn.Linear(
            self.cnn.fc.in_features, 20)
        self.fc1 = nn.Linear(20 + num_features, num_classes)
    #
    def forward(self, image, data):
        x1 = self.cnn(image)
        x2 = data
        x = torch.cat((x1, x2), dim=1)
        x = self.fc1(x)
        return x

# setup
torch.manual_seed(2809)
data = torch.cat((torch.randn(5, 3, 224, 224), torch.randn(5, 3, 224, 224)+1), 0)
target = torch.cat((torch.zeros(5,), torch.ones(5,)), 0).long()

# vanilla model
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

nb_epochs = 10
for epoch in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {}'.format(epoch, loss.item()))
print(torch.argmax(output, 1) == target)

# use additional input
model = ResNetFeature(2, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data2 = torch.randn(10, 10) * 1000

nb_epochs = 20
for epoch in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data, data2)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    print('epoch {}, loss {}'.format(epoch, loss.item()))

print(torch.argmax(output, 1) == target)

The range difference might come from, e.g., the (batch)normalized outputs of the resnet, while your additional input values might use different ranges.
Normalizing might help, but I’m not aware of any extensive testing of the best approaches.
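E.g. standardizing the additional features before feeding them to the model (a simple sketch, not a tested recommendation):

import torch

data2 = torch.randn(10, 10) * 1000                               # additional features in a large range
data2 = (data2 - data2.mean(dim=0)) / (data2.std(dim=0) + 1e-8)  # zero mean / unit std per feature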


Thank you for your detailed explanation and example @ptrblck! This makes perfect sense. I am able to achieve comparable accuracy, but it does require about 20X as many epochs to converge. Clearly I need to find more informative numeric features to include if I want this to work.

Hello @ptrblck,

Just wondering what the implications are if I decide to use transfer learning (TL) with the suggestion you gave:

pretrained = True

Since we are modifying the final architecture of the CNN, it is not clear whether using TL is a better choice in terms of the performance of the network.

Thanks.


I can’t give a general answer, as it might depend on your actual use case, but you could also treat the pretrained model as a “well initialized” starting point for your fine-tuning use case.
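E.g. a minimal sketch reusing the pretrained backbone with the custom head from before (the head size of 20 is just taken from the earlier snippet):

import torch.nn as nn
import torchvision.models as models

model = models.resnet18(pretrained=True)          # ImageNet weights as a "well initialized" starting point
model.fc = nn.Linear(model.fc.in_features, 20)    # the replaced head is trained from scratch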
