RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x30720 and 409600x64)

I have this encoder architecture:

class Encoder(nn.Module):
    def __init__(self, in_channel, img_width, hidden_dim, device, max_filters=512, num_layers=4, small_conv=False, norm_type = 'batch', num_groups=1, kernel_size=4, stride_size=2, padding_size=0, activation = nn.PReLU()):
        super(Encoder,self).__init__()
        self.nchannel    = in_channel
        self.hidden_dim  = hidden_dim
        self.img_width   = img_width
        self.device      = device
        self.enc_kernel  = kernel_size
        self.enc_stride  = stride_size
        self.enc_padding = padding_size
        self.res_kernel  = 3
        self.res_stride  = 1
        self.res_padding = 1
        self.activation  = activation
        ########################
        # ENCODER-CONVOLUTION LAYERS
        if small_conv:
            num_layers += 1
        channel_sizes = calculate_channel_sizes(
            self.nchannel, max_filters, num_layers
        )

        # Encoder
        encoder_layers = nn.ModuleList()
        # Encoder Convolutions
        for i, (in_channels, out_channels) in enumerate(channel_sizes):
            if small_conv and i == 0:
                # 1x1 Convolution
                encoder_layers.append(nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                ))
            else:
                encoder_layers.append( nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=self.enc_kernel,
                        stride=self.enc_stride,
                        padding=self.enc_padding,
                        bias=False,
                    ))
            # Batch Norm
            if norm_type == 'batch':
                encoder_layers.append(nn.BatchNorm2d(out_channels))
            elif norm_type == 'layer':
                encoder_layers.append(nn.GroupNorm(num_groups, out_channels ))

            # Activation (PReLU by default)
            encoder_layers.append(self.activation)

            if (i==num_layers//2):
                #add a residual Layer
                encoder_layers.append(ResidualBlock(
                        out_channels,
                        self.res_kernel,
                        self.res_stride,
                        self.res_padding,
                        norm_type=norm_type,
                        nonlinearity=self.activation
                    ))

        # Flatten Encoder Output
        encoder_layers.append(nn.Flatten())

        self.encoder = nn.Sequential(*encoder_layers)
        

        # Calculate shape of the flattened image
        self.h_dim, self.h_image_dim = self.get_flattened_size(self.img_width)
        #linear layers
        layers = []
        layers.append(nn.Linear(self.h_dim, hidden_dim, bias=False))
        if norm_type == 'batch':
            layers.append(nn.BatchNorm1d(hidden_dim))
        elif norm_type == 'layer':
            layers.append(nn.LayerNorm(hidden_dim ))
        layers.append(self.activation)
        layers.append(nn.Linear(hidden_dim, hidden_dim, bias=False))
        if norm_type == 'batch':
            layers.append(nn.BatchNorm1d(hidden_dim))
        elif norm_type == 'layer':
            layers.append(nn.LayerNorm(hidden_dim ))
        layers.append(self.activation)
        self.linear_layers = nn.Sequential( *layers)
        
        self.to(device=self.device)

    def forward(self, X):
        # Encode (note: the input tensor must have shape [batch_size, channels, height, width])
        h = self.encoder(X)
        print(f"computed hidden dimension {self.h_dim}, {h.shape} and {self.h_image_dim}, input shape {X.shape}, architecture {self.encoder} input channel {self.nchannel}, image width {self.img_width}")
        # Get latent variables
        return self.linear_layers(h)
    def get_flattened_size(self, image_size):
        # Walk every Conv2d in the encoder (including those inside the ResidualBlock)
        # and track how the spatial size changes layer by layer.
        for layer in self.encoder.modules():
            if isinstance(layer, nn.Conv2d):
                kernel_size = layer.kernel_size[0]
                stride = layer.stride[0]
                padding = layer.padding[0]
                filters = layer.out_channels
                image_size = calculate_layer_size(
                    image_size, kernel_size, stride, padding
                )
        # Flattened feature count and the final spatial size
        return filters * image_size * image_size, image_size

Here is how I called this class:

def calculate_conv_params(input_size):
    height, width, channels = input_size
    if height > 100 or width > 100:
        kernel_size = 5
    elif height > 50 or width > 50:
        kernel_size = 4
    else:
        kernel_size = 3
   
    stride = 1  # To keep spatial dimensions same, stride should be 1

    padding = (kernel_size - 1) // 2
   
    return kernel_size, stride, padding

kernel, stride, padding = calculate_conv_params((input_width, input_width, input_channel))
self.base = Encoder(input_channel, input_width, self.hidden_size, device, max_filters=256, num_layers=3, kernel_size=kernel, stride_size=stride, padding_size=padding)

Here is the printed architecture together with the input and hidden sizes seen by the linear layers:

computed hidden dimension 409600, torch.Size([256, 30720]) and 40, input shape torch.Size([256, 40, 40, 3]), architecture Sequential(
  (0): Conv2d(40, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): PReLU(num_parameters=1)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): PReLU(num_parameters=1)
  (6): ResidualBlock(
    (layers): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): PReLU(num_parameters=1)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): PReLU(num_parameters=1)
    )
  )
  (7): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): PReLU(num_parameters=1)
  (10): Flatten(start_dim=1, end_dim=-1)
) input channel 40, image width 40

Here is the error message for this specific input; the same code worked for other images:

actor_features = self.base(obs)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lustre03/project/6057506/onpolicy/algorithms/utils/cnn.py", line 268, in forward
    return self.linear_layers(h)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x30720 and 409600x64)

Can you kindly suggest a way to fix this error?

The printout shows why the multiplication fails: the encoder's flattened output for this input has 30720 features, while the first linear layer was built for the 409600 features that get_flattened_size computed from img_width. If you are using inputs with a variable shape, you need to make sure the number of features is static before feeding the intermediate activation to the first linear layer. Adaptive pooling layers are usually used to create a static feature dimension.
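
Below is a minimal sketch (not your Encoder; the channel sizes and the 4x4 pooled grid are placeholder assumptions) showing how an nn.AdaptiveAvgPool2d placed right before nn.Flatten keeps the flattened feature count fixed for differently sized inputs:

import torch
import torch.nn as nn

pooled_size = 4  # target spatial size after pooling (arbitrary choice)

encoder = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.PReLU(),
    nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.PReLU(),
    nn.AdaptiveAvgPool2d(pooled_size),  # collapses any H x W to 4 x 4
    nn.Flatten(),
)

# The linear layer's input size no longer depends on the input resolution.
linear = nn.Linear(256 * pooled_size * pooled_size, 64)

for h, w in [(40, 40), (32, 48)]:
    x = torch.randn(2, 3, h, w)  # NCHW layout, as expected by nn.Conv2d
    out = linear(encoder(x))
    print(out.shape)  # torch.Size([2, 64]) for both input sizes

In your Encoder the equivalent change would be to append nn.AdaptiveAvgPool2d(pooled_size) just before nn.Flatten() and compute self.h_dim from the last channel count and pooled_size instead of from img_width.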