How to fix "mat1 and mat2 shapes cannot be multiplied" when using different image sizes between training and testing with a Linear layer

Hi!

I'm using the code below as an encoder.

nn.ModuleList([
    nn.Sequential(nn.Conv2d(1, self.image_size, 3, padding=1),
                  nn.LeakyReLU(),
                  nn.Conv2d(self.image_size, self.image_size, 3, padding=1)),

    nn.Sequential(nn.LeakyReLU(),
                  nn.Conv2d(self.image_size, self.image_size*2, 3, padding=1),
                  nn.LeakyReLU(),
                  nn.Conv2d(self.image_size*2, self.image_size*2, 3, padding=1, stride=2)),

    nn.Sequential(nn.LeakyReLU(),
                  nn.Conv2d(self.image_size*2, self.image_size*2, 3, padding=1, stride=2)),

    nn.Sequential(nn.LeakyReLU(),
                  nn.AvgPool2d(4),
                  Flatten(),
                  nn.Linear(2048, 512))])

I would like to input an image of size (1, 256, 1100). However, if I feed it in directly, I get a CUDA out of memory error, so I crop and down-sample the input to (1, 64, 64) during training. I would like to use the original-size image in the test phase, but the final fully-connected layer then fails, because the flattened feature size differs between the two input resolutions, with the error below.

mat1 and mat2 shapes cannot be multiplied  (1x60928 and 2048x512)

Does anyone have a tip or a solution?
Any help is greatly appreciated!
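The mismatch can be reproduced by tracing the activation shapes of the encoder's spatial path for both input sizes. The sketch below uses a placeholder channel count (8) and only the stride-2 convs and the AvgPool2d, so the exact numbers differ from your setup, but it shows why a fixed `nn.Linear` input size can only match one input resolution:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the encoder's spatial path: two stride-2 convs
# followed by AvgPool2d(4). The channel count (8) is a placeholder.
body = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1, stride=2),
    nn.Conv2d(8, 8, 3, padding=1, stride=2),
    nn.AvgPool2d(4),
    nn.Flatten(),
)

# 64x64 crop: 64 -> 32 -> 16 -> 4x4 after pooling
print(body(torch.randn(1, 1, 64, 64)).shape)     # (1, 8*4*4) = (1, 128)
# full image: 256x1100 -> 128x550 -> 64x275 -> 16x68 after pooling
print(body(torch.randn(1, 1, 256, 1100)).shape)  # (1, 8*16*68) = (1, 8704)
```

A `nn.Linear` sized for the first flattened shape cannot accept the second, which is exactly the error you are seeing.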

You would need to either adapt the input features of the last linear layer and set it to 60928 (which is quite large), or reduce the activation shape further, e.g. by using adaptive pooling layers.
These layers will output a defined spatial shape.
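For instance, `nn.AdaptiveAvgPool2d` pools any input spatial size down to the requested output size (a minimal sketch with a placeholder channel count):

```python
import torch
import torch.nn as nn

pool = nn.AdaptiveAvgPool2d(32)  # always outputs a 32x32 spatial map

a = pool(torch.randn(1, 6, 64, 64))
b = pool(torch.randn(1, 6, 256, 1100))
print(a.shape, b.shape)  # both are (1, 6, 32, 32)
```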
E.g. for image_size = 3 this would work:

import torch
import torch.nn as nn


class Print(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        print(x.shape)
        return x


image_size = 3
modules = nn.ModuleList([
    nn.Sequential(nn.Conv2d(1,  image_size, 3, padding=1),
                  nn.LeakyReLU(),
                  nn.Conv2d(image_size, image_size, 3, padding=1)),
    nn.Sequential(nn.LeakyReLU(),
                  nn.Conv2d(image_size,  image_size*2, 3, padding=1),
                  nn.LeakyReLU(),
                  nn.Conv2d(image_size*2, image_size*2, 3, padding=1, stride=2)),
    nn.Sequential(nn.LeakyReLU(),
                  nn.Conv2d(image_size*2, image_size*2, 3, padding=1, stride=2)),
    nn.Sequential(nn.LeakyReLU(),
                  nn.AdaptiveAvgPool2d(32),
                  Print(),
                  nn.Flatten(),
                  Print(),
                  nn.Linear(32*32*image_size*2, 512))
])

x = torch.randn(1, 1, 256, 1100)

out = modules[0](x)
for module in modules[1:]:
    out = module(out)

The Print module allows you to check intermediate activation shapes.
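With the adaptive pooling layer in place, the same encoder accepts both the 64x64 training crops and the full-size 256x1100 test images and produces the same output shape. A self-contained sketch (flattening the ModuleList into one nn.Sequential for brevity):

```python
import torch
import torch.nn as nn

image_size = 3
encoder = nn.Sequential(
    nn.Conv2d(1, image_size, 3, padding=1),
    nn.LeakyReLU(),
    nn.Conv2d(image_size, image_size, 3, padding=1),
    nn.LeakyReLU(),
    nn.Conv2d(image_size, image_size * 2, 3, padding=1),
    nn.LeakyReLU(),
    nn.Conv2d(image_size * 2, image_size * 2, 3, padding=1, stride=2),
    nn.LeakyReLU(),
    nn.Conv2d(image_size * 2, image_size * 2, 3, padding=1, stride=2),
    nn.LeakyReLU(),
    nn.AdaptiveAvgPool2d(32),  # fixes the spatial size regardless of input
    nn.Flatten(),
    nn.Linear(32 * 32 * image_size * 2, 512),
)

crop = torch.randn(1, 1, 64, 64)     # training crop
full = torch.randn(1, 1, 256, 1100)  # full-size test image
print(encoder(crop).shape, encoder(full).shape)  # both torch.Size([1, 512])
```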