class Network(nn.Module):
    """Small CNN: one 5x5 conv stage followed by a two-layer fully connected head.

    ``nn.Linear(9999, 450)`` is a deliberate placeholder. The real number of
    flattened features depends on the input image size, so running the model
    once on a sample batch makes PyTorch raise a shape-mismatch error. That
    error shows the feature count the first Linear layer actually receives.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # 1 input channel -> 50 feature maps; with padding=0 and a 5x5
            # kernel (stride 1) the spatial size shrinks from H x W to
            # (H - 4) x (W - 4).
            nn.Conv2d(1, 50, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(50),
            nn.Flatten(),
            # Placeholder in_features: the correct value is 50 * (H - 4) * (W - 4)
            # for H x W inputs (e.g. 2_420_000 for 224 x 224).
            nn.Linear(9999, 450),
            nn.ReLU(inplace=True),
            nn.Linear(450, 10),
        )

    def forward(self, x):
        """Apply the conv stage and the fully connected head.

        Assumes ``x`` is a batch of shape (N, 1, H, W). Returns a tensor of
        shape (N, 10).
        """
        return self.model(x)
import torchinfo

# Instantiate the model and print a layer-by-layer summary for a dummy batch
# of 32 single-channel 224x224 images. With the placeholder Linear(9999, ...)
# in the first Network definition, this run is expected to fail. The
# resulting shape-mismatch error shows the flattened size the first Linear
# layer really receives.
model = Network()
torchinfo.summary(model, input_size=(32, 1, 224, 224))
Just run it like this once, and PyTorch will raise an error showing the input size that the nn.Linear layer actually receives (i.e. the value its in_features should be).
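Or, if you know the input shape, you can calculate the output shape yourself (I prefer padding='same' in these situations, because with stride 1 it keeps the spatial size unchanged, which makes the calculation trivial):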
class Network(nn.Module):
    """Small CNN whose first Linear layer is sized from the expected input shape.

    Args:
        input_shape: (H, W) of the single-channel images the model will receive.
        conv_out_channels: number of feature maps produced by the conv layer.

    The convolution uses ``padding='same'`` with stride 1, so it keeps the
    spatial size at H x W. The flattened feature count is therefore
    H * W * conv_out_channels.

    NOTE: with the defaults (224 x 224, 50 channels) that is 2_508_800
    features, so the first Linear layer alone holds about 1.1e9 weights
    (~4.5 GB in float32). Consider pooling or striding before flattening for
    real use.
    """

    def __init__(self, input_shape=(224, 224), conv_out_channels=50):
        super().__init__()
        # Spatial size is preserved by padding='same', so only the channel
        # count changes between the input and the flattened features.
        shape_after_conv = input_shape[0] * input_shape[1] * conv_out_channels
        self.model = nn.Sequential(
            nn.Conv2d(1, conv_out_channels, kernel_size=5, stride=1, padding='same'),
            nn.ReLU(inplace=True),
            # Must match the conv output channels; a hard-coded 50 breaks any
            # other conv_out_channels value.
            nn.BatchNorm2d(conv_out_channels),
            nn.Flatten(),
            nn.Linear(shape_after_conv, 450),
            nn.ReLU(inplace=True),
            nn.Linear(450, 10),
        )

    def forward(self, x):
        """Apply the conv stage and the fully connected head.

        Assumes ``x`` has shape (N, 1, H, W) with (H, W) equal to
        ``input_shape``. Returns a tensor of shape (N, 10).
        """
        return self.model(x)
It's not as if your data will have different shapes every day. Even standard "default" models like ResNets have a predefined input shape such as 224x224. Changing it is possible, but there's very little reason to, since they were pretrained for that exact shape.