Issue with PyTorch Autoencoder Linear Layer Shape (mat1 and mat2 shapes cannot be multiplied)

Hello dear community,

I’m currently working on building an autoencoder in PyTorch that takes an input of shape [batch_size, 1, 21] and aims to map it to a bottleneck dimension of 3 after passing through the encoder. The model architecture involves convolutional and linear layers, and I’m encountering a runtime error that I’m having trouble resolving.

Here is the architecture of my model and the relevant training module:

import torch
import torch.nn as nn

# assumed device setup (not shown in the original snippet)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

#Model parameters:
LAYERS = 6
KERNELS = [4, 4, 4, 4, 4, 4]
CHANNELS = [1, 32, 64, 128, 256, 512]
STRIDES = [1, 1, 1, 1, 1, 1]
LINEAR_DIM = 16384

labels = ["0","1","2"]

config = {"batch_size" : 32,
          "epochs": 3,
          "lr" : 5e-4}


class Encoder(nn.Module):
    
    def __init__(self, output_dim = 3, use_batchnorm=False, use_dropout=False):
        super(Encoder, self).__init__()

        #bottleneck dimensionality
        self.output_dim = output_dim

        #flags deciding whether to use dropout and batchnorm in the model
        self.use_dropout = use_dropout
        self.use_batchnorm = use_batchnorm

        #convolutional layer hyperparameters
        self.layers = LAYERS
        self.kernels = KERNELS
        self.channels = CHANNELS
        self.strides = STRIDES
        self.conv = self.get_convs()
        
        #layers for latent space projection
        self.fc_dim = LINEAR_DIM
        self.linear = nn.Linear(self.fc_dim, self.output_dim)

    def get_convs(self):
        """
        Generate the convolutional layers based on the
        model's hyperparameters.
        """
        conv_layers = nn.Sequential()
    
        for i_layer in range(self.layers):
            #The input channel of the first layer is 1:
            if i_layer == 0: conv_layers.append(nn.Conv1d(in_channels = 1,
                                                          out_channels = self.channels[i_layer],
                                                          kernel_size = self.kernels[i_layer],
                                                          stride = self.strides[i_layer],
                                                          padding = 0))
            else: conv_layers.append(nn.Conv1d(in_channels = self.channels[i_layer - 1],
                                           out_channels = self.channels[i_layer],
                                           kernel_size = self.kernels[i_layer],
                                           stride = self.strides[i_layer],
                                           padding = 0))
            if self.use_batchnorm: conv_layers.append(nn.BatchNorm1d(self.channels[i_layer]))
                
            #Activation Function:
            conv_layers.append(nn.GELU())
            
            if self.use_dropout: conv_layers.append(nn.Dropout1d(0.15))
        
        return conv_layers
    
    def forward(self, x):       
      
        print("Shape before conv: ", x.shape)        
        x = self.conv(x)
        print("Shape after conv: ", x.shape)    
       
        x = self.linear(x)
        print("Shape after linear: ", x.shape)
        
        return x


class Decoder(nn.Module):
    
    def __init__(self, input_dim = 3, use_batchnorm=False, use_dropout=False):
        super(Decoder, self).__init__()

        #flags deciding whether to use dropout and batchnorm in the model
        self.use_dropout = use_dropout
        self.use_batchnorm = use_batchnorm

        self.fc_dim = LINEAR_DIM
        self.input_dim = input_dim         
        
        #convolutional layer hyperparameters
        self.layers = LAYERS
        self.kernels = KERNELS
        self.channels = CHANNELS[::-1] #flips the channels dimensions 
        self.strides = STRIDES
        
        
        # In the decoder, we first do the fc projection, then the conv layers
        self.linear = nn.Linear(self.input_dim, self.fc_dim)
        self.conv =  self.get_convs()

        self.output = nn.Conv1d(self.channels[-1], 1, kernel_size=4, stride=1)
        

    def get_convs(self):
        """
        Generate the convolutional layers based on the
        model's hyperparameters.
        """
        conv_layers = nn.Sequential()
    
        for i_layer in range(self.layers):
            #The first transposed conv layer keeps the channel count (512 -> 512):
            if i_layer == 0: conv_layers.append(nn.ConvTranspose1d(in_channels = self.channels[i_layer],
                                                                   out_channels = self.channels[i_layer],
                                                                   kernel_size = self.kernels[i_layer],
                                                                   stride = self.strides[i_layer],
                                                                   padding = 0,
                                                                   output_padding = 0))
                
            else: conv_layers.append(nn.ConvTranspose1d(in_channels = self.channels[i_layer - 1],
                                                        out_channels = self.channels[i_layer],
                                                        kernel_size = self.kernels[i_layer],
                                                        stride = self.strides[i_layer],
                                                        padding = 0,
                                                        output_padding = 0))
                
            if self.use_batchnorm and i_layer != self.layers - 1: conv_layers.append(nn.BatchNorm1d(self.channels[i_layer]))
                
            #Activation Function:
            conv_layers.append(nn.GELU())
            
            if self.use_dropout: conv_layers.append(nn.Dropout1d(0.15))
        
        return conv_layers
    
    def forward(self, x):
        x = self.linear(x)
        #reshape the 2D linear output to a 4D tensor
        x = x.reshape(x.shape[0], 128, 4, 4)
        x = self.conv(x)
        return self.output(x)


class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(output_dim = 3, use_batchnorm = True, use_dropout = False)
        self.decoder = Decoder(input_dim = 3, use_batchnorm = True, use_dropout = False)
    
    def forward(self, x):
        return self.decoder(self.encoder(x))

#TRAINING
model = AutoEncoder()
model = model.to(DEVICE)

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-5)

# For mixed precision training
scaler = torch.cuda.amp.GradScaler()
steps = 0 # tracking the training steps

def train(model, dataloader, criterion, optimizer, save_distrib = False):
    # steps is used to track training progress, purely for latent space plots
    global steps
    
    model.train()
    train_loss = 0.0    
   
    for i, batch in enumerate(dataloader):
        
        optimizer.zero_grad()
        
        x = batch[0].to(DEVICE)
        
        # Here we implement the mixed precision training
        with torch.cuda.amp.autocast():
            y_recons = model(x)
            loss = criterion(y_recons, x)
            
        train_loss = train_loss + loss.item()
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
          
        # Saving latent space plots
        if steps % 10 == 0 and save_distrib and steps <=400: plotting(steps)
        
        steps = steps + 1
        
        # remove unnecessary cache in CUDA memory
        torch.cuda.empty_cache()
        del x, y_recons
    
    #batch_bar.close()
    train_loss = train_loss / len(dataloader)
    
    return train_loss  

for i in range(config["epochs"]):
    
    curr_lr = float(optimizer.param_groups[0]["lr"])
    
    train_loss = train(model, train_loader, criterion, optimizer, save_distrib = True)
    
    print(f"Epoch {i+1}/{config['epochs']}\nTrain loss: {train_loss:.4f}\tlr: {curr_lr:.4f}")

The issue arises when I run the code, and I receive the following error message:

Shape before conv:  torch.Size([32, 1, 21])
Shape after conv:  torch.Size([32, 512, 3])
...
RuntimeError: mat1 and mat2 shapes cannot be multiplied (16384x3 and 16384x3)

After researching similar posts, I understand that the source of this issue lies in the linear layer. However, I’m confused because both mat1 and mat2 have dimensions of 16384x3.
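
For reference, here is a minimal sketch that reproduces the error using only the shapes from the debug prints above (assuming the same LINEAR_DIM and bottleneck size):

import torch
import torch.nn as nn

# shape reported right after the conv stack: [batch, channels, length]
x = torch.randn(32, 512, 3)

# the encoder's projection layer: nn.Linear(LINEAR_DIM, output_dim)
linear = nn.Linear(16384, 3)

# nn.Linear only contracts the last dimension, so the input is treated as
# [32 * 512, 3] = [16384, 3] and the (transposed) weight is also [16384, 3]
linear(x)
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (16384x3 and 16384x3)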

The full code is here.

I would greatly appreciate any insights or suggestions!

This is wrong, as you would need to transpose one of the matrices. Depending on whether the desired output should be [16384, 16384] or [3, 3], you should transpose the second or the first matrix, respectively.
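
For illustration, a quick sketch with the two shapes from the error message:

import torch

mat1 = torch.randn(16384, 3)
mat2 = torch.randn(16384, 3)

print((mat1 @ mat2.T).shape)  # torch.Size([16384, 16384])
print((mat1.T @ mat2).shape)  # torch.Size([3, 3])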


So the issue goes even deeper, because I want the output to be [1, 3], i.e., I want the output of the 1D conv layers to be transformed to [batch_size, 1, 3] by the linear layer.

In that case you would need to permute the input activation before applying the linear layer and permute it back afterwards:

x = torch.randn([32, 512, 3])
lin = nn.Linear(512, 1)
out = lin(x.permute(0, 2, 1))
out = out.permute(0, 2, 1)
print(out.shape)
# torch.Size([32, 1, 3])

Thank you @ptrblck! Now I'm getting my data through all the layers:

Shape before conv:  torch.Size([32, 1, 21])
Shape after conv:  torch.Size([32, 512, 3])
Shape after first permutation:  torch.Size([32, 3, 512])
Shape after linear:  torch.Size([32, 3, 1])
Shape after second permutation:  torch.Size([32, 1, 3])
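
For completeness, the updated Encoder.forward presumably looks roughly like this (a sketch; the projection layer is assumed to be nn.Linear(512, 1) here, since the printed shape after the linear layer is [32, 3, 1]):

    def forward(self, x):
        print("Shape before conv: ", x.shape)                # [32, 1, 21]
        x = self.conv(x)
        print("Shape after conv: ", x.shape)                 # [32, 512, 3]

        x = x.permute(0, 2, 1)                               # move channels to the last dim
        print("Shape after first permutation: ", x.shape)    # [32, 3, 512]

        x = self.linear(x)                                   # self.linear = nn.Linear(512, 1)
        print("Shape after linear: ", x.shape)               # [32, 3, 1]

        x = x.permute(0, 2, 1)
        print("Shape after second permutation: ", x.shape)   # [32, 1, 3]

        return x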