How to Reshape a 2D tensor into a 4D tensor?

Hi! I want to reshape a tensor of size [batch_size, c*h*w] = [24, 1152] into one of size [batch_size, c, h, w] = [24, 128, 3, 3], but I can’t figure out how to do it. I’ve already tried the .view function. The 2D tensor is the output of a linear layer, and I want the 4D tensor to be the input of an nn.ConvTranspose2d. This is the code; I’m trying to build a convolutional autoencoder with fully-connected layers in the middle:

import torch.nn as nn

class ConvAutoenc(nn.Module):
    
    def __init__(self, hidden_size):
        super().__init__() 
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1), #out size= 14x14x32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), #out size= 7x7x64
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=0), #out size= 3x3x128
            nn.ReLU(),
            nn.Flatten()
        ) #out size= 1152 = 3*3*128
            
        self.hidden = nn.Sequential(
            nn.Linear(1152, hidden_size),
            nn.Linear(hidden_size, 1152))
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=3,stride=2, padding=0), 
            nn.ReLU(),#out = 7x7x64
            nn.ConvTranspose2d(in_channels=64,out_channels=32,kernel_size=4, stride=2, padding=1),
            nn.ReLU(),#out = 14x14x32
            nn.ConvTranspose2d(in_channels=32,out_channels=1,kernel_size=4, stride=2, padding=1)
        )
        
    def forward(self, xb):
        out = self.encoder(xb) #out size= 1152 = 3*3*128
        out = self.hidden(out) #out size= 1152
        z = out.view(-1, (128, 3, 3))
        out = self.decoder(z)
        return out

When I run it, this error shows up:

“RuntimeError: Expected 4-dimensional input for 4-dimensional weight [128, 64, 3, 3], but got 2-dimensional input of size [24, 1152] instead”

I’d appreciate any hints. Thanks in advance.

You can pass the target sizes to view as separate integers instead of a nested tuple:

out.view(out.shape[0], 128, 3, 3)
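
To check the shapes in isolation, here is a minimal sketch. It assumes a random dummy tensor standing in for the output of the linear layer, and the variable names (z, z_inferred, z_module, decoder_first) are only illustrative. view accepts the sizes either as separate integers or as one flat tuple, but not as an integer followed by a nested tuple, which is what out.view(-1, (128, 3, 3)) passes. nn.Unflatten is shown as one possible alternative if you would rather keep the reshape inside an nn.Sequential.

```python
import torch
import torch.nn as nn

# Dummy stand-in for the linear layer's output: [batch_size, c*h*w] = [24, 1152]
out = torch.randn(24, 1152)

# Sizes passed as separate integers; the batch size is read from the tensor
z = out.view(out.shape[0], 128, 3, 3)

# Equivalent: let view infer the batch dimension with -1
z_inferred = out.view(-1, 128, 3, 3)

# Alternative (an assumption about how you might structure the model):
# nn.Unflatten does the same reshape as a module, so it could sit at the
# start of the decoder's nn.Sequential
z_module = nn.Unflatten(1, (128, 3, 3))(out)

# First decoder layer from the question, to confirm it now accepts the input
decoder_first = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=0)
print(z.shape)                 # torch.Size([24, 128, 3, 3])
print(decoder_first(z).shape)  # torch.Size([24, 64, 7, 7])
```

All three variants give the same tensor; view only changes how the existing data is indexed, so it does not copy anything.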

Thanks! It works now