Convolutional autoencoder dimension error

Hi,

I’m trying to adapt the architecture from here to run on 3D volumes of size 182 x 218 x 182, i.e. treating the first dimension as 182 input channels (far more than the standard 3 RGB channels) and working with an uneven 218 x 182 height-width ratio.

Here are my Encoder, Decoder and Autoencoder:

class Encoder(nn.Module):
    
    def __init__(self, 
                 num_input_channels : int, 
                 base_channel_size : int, 
                 latent_dim : int, 
                 act_fn : object = nn.GELU):
        """
        Inputs: 
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a multiple of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 218x182 => 109x91
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 109x91 => 55x46
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 55x46 => 28x23
            act_fn(),
            nn.Flatten(), # Image grid to single feature vector
            nn.Linear(2*16*c_hid, latent_dim)
        )
    
    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    
    def __init__(self, 
                 num_input_channels : int, 
                 base_channel_size : int, 
                 latent_dim : int, 
                 act_fn : object = nn.GELU):
        """
        Inputs: 
            - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the last convolutional layers. Early layers might use a multiple of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            #nn.Linear(latent_dim, 2*32*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=(2,2)), # 4x4 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=(5,5)), # 8x8 => 37x37
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=(6,5)), # 37x37 => 218x182
            nn.ReLU()
        )
    
    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x
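
(For reference on the strides above: with kernel_size=3, padding=1 and the default dilation, nn.ConvTranspose2d produces out = (in - 1) * stride + output_padding + 1 per spatial dimension, so the decoder upsamples 4x4 => 8x8 with stride (2,2), 8x8 => 37x37 with stride (5,5), and 37x37 => 218x182 with stride (6,5), which is where the target 218 x 182 size comes from.)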

class Autoencoder(pl.LightningModule):
    
    def __init__(self, 
                 base_channel_size: int, 
                 latent_dim: int, 
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 #num_input_channels: int = 3, 
                 num_input_channels: int = 182, 
                 #width: int = 32, 
                 width: int = 218, 
                 #height: int = 32):
                 height: int = 182):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters() 
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)
        
        #self.automatic_optimization = False
        
    def forward(self, x):
        """
        The forward function takes in an image and returns the reconstructed image
        """
        z = self.encoder(x)
        print('done encoder')
        x_hat = self.decoder(z)
        print('done decoder')
        return x_hat
    
    def _get_reconstruction_loss(self, batch):
        """
        Given a batch of images, this function returns the reconstruction loss (MSE in our case)
        """
        x, _ = batch # We do not need the labels
        x_hat = self.forward(x)
        loss = F.mse_loss(x, x_hat, reduction="none")
        loss = loss.sum(dim=[1,2,3]).mean(dim=[0]) #.to('cpu')
        print(loss)
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                         mode='min', 
                                                         factor=0.2, 
                                                         patience=20, 
                                                         min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
        #return {"optimizer": optimizer, "monitor": "val_loss"}
    
    def training_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)                             
        self.log('train_loss', loss)
        return loss #.backward()
    
    def validation_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log('val_loss', loss)
    
    def test_step(self, batch, batch_idx):
        loss = self._get_reconstruction_loss(batch)
        self.log('test_loss', loss)

The output of torchinfo’s summary seems right:

mo = Autoencoder(base_channel_size=32, latent_dim=64, num_input_channels=182) #, width=218, height=182)
summary(mo, (10, 182, 218, 182))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Autoencoder                              --                        --
├─Encoder: 1-1                           [10, 64]                  --
│    └─Sequential: 2-1                   [10, 64]                  --
│    │    └─Conv2d: 3-1                  [10, 32, 109, 91]         52,448
│    │    └─GELU: 3-2                    [10, 32, 109, 91]         --
│    │    └─Conv2d: 3-3                  [10, 32, 109, 91]         9,248
│    │    └─GELU: 3-4                    [10, 32, 109, 91]         --
│    │    └─Conv2d: 3-5                  [10, 64, 55, 46]          18,496
│    │    └─GELU: 3-6                    [10, 64, 55, 46]          --
│    │    └─Conv2d: 3-7                  [10, 64, 55, 46]          36,928
│    │    └─GELU: 3-8                    [10, 64, 55, 46]          --
│    │    └─Conv2d: 3-9                  [10, 64, 28, 23]          36,928
│    │    └─GELU: 3-10                   [10, 64, 28, 23]          --
│    │    └─Flatten: 3-11                [10, 41216]               --
│    │    └─Linear: 3-12                 [10, 64]                  65,600
├─Decoder: 1-2                           [10, 182, 218, 182]       --
│    └─Sequential: 2-2                   [10, 1024]                --
│    │    └─Linear: 3-13                 [10, 1024]                66,560
│    │    └─GELU: 3-14                   [10, 1024]                --
│    └─Sequential: 2-3                   [10, 182, 218, 182]       --
│    │    └─ConvTranspose2d: 3-15        [10, 64, 8, 8]            36,928
│    │    └─GELU: 3-16                   [10, 64, 8, 8]            --
│    │    └─Conv2d: 3-17                 [10, 64, 8, 8]            36,928
│    │    └─GELU: 3-18                   [10, 64, 8, 8]            --
│    │    └─ConvTranspose2d: 3-19        [10, 32, 37, 37]          18,464
│    │    └─GELU: 3-20                   [10, 32, 37, 37]          --
│    │    └─Conv2d: 3-21                 [10, 32, 37, 37]          9,248
│    │    └─GELU: 3-22                   [10, 32, 37, 37]          --
│    │    └─ConvTranspose2d: 3-23        [10, 182, 218, 182]       52,598
│    │    └─ReLU: 3-24                   [10, 182, 218, 182]       --
==========================================================================================
Total params: 440,374
Trainable params: 440,374
Non-trainable params: 0
Total mult-adds (G): 29.06
==========================================================================================
Input size (MB): 288.84
Forward/backward pass size (MB): 665.42
Params size (MB): 1.76
Estimated Total Size (MB): 956.03
==========================================================================================

Yet there is an error when running the model through PyTorch (Lightning), which seems to suggest that backpropagation is skipping the bottleneck layer: AddmmBackward expects the shape of the Encoder’s flattened penultimate output (41216 features), whereas it sees the shape of the Decoder’s first linear layer output (1024 features):

RuntimeError: Function AddmmBackward returned an invalid gradient at index 1 - got [50, 1024] but expected shape compatible with [50, 41216]

Any ideas on what I could be missing here? Thanks!

Are you seeing the same issue using random input data and the posted model definition without using Lightning, or is it specific to PTL?

@ptrblck Yes, there is an equivalent issue using PyTorch (without Lightning):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-33-403ad95c5423> in <module>
     25         optimizer.zero_grad()
     26         # forward pass: compute predicted outputs by passing inputs to the model
---> 27         outputs = model(images)
     28         # calculate the loss
     29         loss = criterion(outputs, images)

/shared/.conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-18-b59389063a94> in forward(self, x)
     27         The forward function takes in an image and returns the reconstructed image
     28         """
---> 29         z = self.encoder(x)
     30         print('done encoder')
     31         x_hat = self.decoder(z)

/shared/.conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

<ipython-input-15-371f0fe0c211> in forward(self, x)
     31 
     32     def forward(self, x):
---> 33         return self.net(x)

/shared/.conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/shared/.conda/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
    137     def forward(self, input):
    138         for module in self:
--> 139             input = module(input)
    140         return input
    141 

/shared/.conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/shared/.conda/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     94 
     95     def forward(self, input: Tensor) -> Tensor:
---> 96         return F.linear(input, self.weight, self.bias)
     97 
     98     def extra_repr(self) -> str:

/shared/.conda/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1845     if has_torch_function_variadic(input, weight):
   1846         return handle_torch_function(linear, (input, weight), input, weight, bias=bias)
-> 1847     return torch._C._nn.linear(input, weight, bias)
   1848 
   1849 

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x41216 and 1024x64)

The new error doesn’t seem to be equivalent, since here the forward pass itself is failing with a shape mismatch in the Encoder’s Linear layer, which is expected given its hard-coded in_features.
Change:

nn.Linear(2*16*c_hid, latent_dim)

to

nn.Linear(41216, latent_dim)

and the model works for me:

num_input_channels, width, height = 182, 218, 182        
example_input_array = torch.zeros(10, num_input_channels, width, height)
model = Autoencoder(base_channel_size=32, latent_dim=64, num_input_channels=182) 

output = model(example_input_array)
output.mean().backward()

so that I cannot reproduce the initially reported shape mismatch in the backward pass.
I don’t know why your initial Lightning model didn’t report the shape mismatch in the forward pass.
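
For completeness, the 41216 comes straight from the three stride-2 convolutions: with kernel_size=3 and padding=1 each of them maps a spatial size n to ceil(n / 2), so 218 => 109 => 55 => 28 and 182 => 91 => 46 => 23, and the last conv outputs 2*c_hid = 64 channels, i.e. 64 * 28 * 23 = 41216 flattened features. One way to avoid hard-coding that number is to run the convolutional part once on a dummy input and read the size off the result (just a sketch, assuming the conv/activation layers are kept in their own nn.Sequential called conv_net, before Flatten/Linear):

# sketch: infer the bottleneck in_features from a dummy forward pass
with torch.no_grad():
    dummy = torch.zeros(1, 182, 218, 182)          # same layout as the posted example input
    n_flat = conv_net(dummy).flatten(1).shape[1]   # 41216 for this input size
fc = nn.Linear(n_flat, latent_dim)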

Thanks a lot @ptrblck!

You are right, the error came from within the Encoder… The solution you proposed works, and it seems like the following might work as well, for variable c_hid sizes:

nn.Linear(2*644*c_hid, latent_dim)
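
(644 here is just 28 * 23, the spatial grid the encoder ends up with for a 218 x 182 input, so it still has to be recomputed if the input size changes. If you want to skip the manual calculation altogether, newer PyTorch versions (1.8+) also provide nn.LazyLinear, which infers in_features on the first forward pass; a minimal sketch of the bottleneck using it:)

# sketch: nn.LazyLinear infers in_features from the first batch it sees
bottleneck = nn.Sequential(
    nn.Flatten(),               # [B, 64, 28, 23] -> [B, 41216]
    nn.LazyLinear(latent_dim),  # no need to hard-code 2*644*c_hid
)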