Looking at `torchvision.models.resnet34`, this is its `forward`:
```python
class ResNet(nn.Module):
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)

        return x
```
`resnet34` could have been a plain `nn.Sequential` of its submodules if it were not for the `reshape`. Manipulating it would then have been more straightforward, and we would not need to treat it differently.
`resnet34` is just an example, but in general it would be nice to also have a simple reshape `nn.Module` and use it instead of re-implementing `forward`.
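A minimal sketch of such a general-purpose module, assuming it keeps the batch dimension and reshapes each sample to a user-given shape. The name `Reshape` and its constructor signature are illustrative choices, not an existing PyTorch API.

```python
import torch
import torch.nn as nn


class Reshape(nn.Module):
    """Hypothetical module: reshape each sample to `shape`, keeping the batch dim."""

    def __init__(self, *shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        # x.size(0) is the batch dimension; the rest is reshaped per sample
        return x.reshape(x.size(0), *self.shape)

    def extra_repr(self):
        # Shown when the module is printed, e.g. "Reshape(-1)"
        return ", ".join(str(s) for s in self.shape)


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),  # -> (N, 8, 1, 1)
    Reshape(-1),              # -> (N, 8)
    nn.Linear(8, 10),         # -> (N, 10)
)

out = model(torch.randn(4, 3, 32, 32))
print(out.shape)  # torch.Size([4, 10])
```

For the common flatten case, PyTorch 1.2 and later also ship a built-in `nn.Flatten`, which flattens from dimension 1 onward by default.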