Two similar networks, completely different training evolution

So, I'm trying to remove the MaxPool layers and replace them with strided convolutions, as suggested in many papers (including the ResNet paper).

This is my net (only the backbone, but that's the only relevant piece):


import torch
import torch.nn as nn


class Backbone(nn.Module):
    # just a set of convolutions

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=2, stride=1, padding=2) # stride=2 in version2
        self.batch1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=2) # stride=2 in v2
        self.batch2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=2, stride=1, padding=2) # ditto
        self.batch3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=2, stride=1, padding=2) # ditto
        self.batch4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=2, stride=1, padding=2) # stride 1 or 2
        self.batch5 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=(3, 3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.batch1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.conv4(x)
        x = self.batch4(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.relu(x)
        return x

As you can see, this is just a series of convolutions followed by batch normalization and, in some cases, max pooling.

If I change this to strided convolutions, the loss gets stuck at a roughly 10× larger value and the model does not improve any more. For reference, the stride version ("version 2") is sketched below.
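
This is roughly what version 2 looks like: the same backbone with the max-pool layers removed and stride=2 in the convolutions, following the comments above (the exact stride of conv5 and which batch-norm layers get applied in forward may differ slightly from what I actually ran, so treat it as a sketch):

class BackboneV2(nn.Module):
    # same layers as above, but the max pools are gone and the
    # downsampling is done by stride=2 in the convolutions instead

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=2, stride=2, padding=2)
        self.batch1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=2, stride=2, padding=2)
        self.batch2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=2)
        self.batch3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 128, kernel_size=2, stride=2, padding=2)
        self.batch4 = nn.BatchNorm2d(128)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=2, stride=2, padding=2)
        self.batch5 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.batch1(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.batch4(self.conv4(x)))
        x = self.relu(self.conv5(x))
        return x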

Is there an obvious reason why this would be a mistake?

I seem to have found the reason in this Cross Validated post: "Pooling vs. stride for downsampling".

It seems that stride=2 does not select for the maximum signal the way max pooling does, and this in turn slows down learning. (I am not currently checking gradients, as I am a beginner, but the gradients were probably small even though the loss was not that small?) A quick way to check would be something like the snippet below.
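
Just as a sketch (not my actual training code): a helper like the one below, called right after loss.backward(), should show whether the stride-2 version really ends up with smaller gradients than the max-pool version (print_grad_norms and the model argument are hypothetical names):

def print_grad_norms(model):
    # print the gradient norm of every parameter; call this right after
    # loss.backward() and compare the two versions of the backbone
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"{name}: grad norm = {param.grad.norm().item():.6f}")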

Maybe someone else can shed some light on this.