Validation Loss Diverging - Custom Resnet

I’ve tried converting my TensorFlow code (a modified ResNet) to PyTorch, but the validation loss in PyTorch diverges after the 3rd or 4th epoch. Here’s my model code:

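import torch
import torch.nn as nn

# ResBlock1, ResBlock2 and nonlinearity are not shown in this post.
class ModifiedRes(nn.Module):
    def __init__(self, input_channels = 3, activation="relu"):
        super().__init__()  # must run before submodules are assigned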
        self.conv0 = nn.Conv2d(in_channels=input_channels, out_channels=32, kernel_size=7)
        self.activation = nonlinearity(activation)
        self.bn0 = nn.BatchNorm2d(num_features = 32)

        self.bl1 = ResBlock1(32, 16, 32, activation)

        self.bl2_1 = ResBlock2(32, 32, 64, activation)
        self.bl2_2 = ResBlock1(64, 32, 64, activation)

        self.bl3_1 = ResBlock2(64, 64, 128, activation)
        self.bl3_2 = ResBlock1(128, 64, 128, activation)

        self.bl4_1 = ResBlock2(128, 128, 256, activation)
        self.bl4_2 = ResBlock1(256, 128, 256, activation)

        self.bl5_1 = ResBlock2(256, 256, 512, activation)
        self.bl5_2 = ResBlock1(512, 256, 512, activation)

        self.dense = nn.Linear(in_features = 512, out_features = 1)

    def forward(self, x):
        out = self.conv0(x)
        out = self.activation(out)
        out = self.bn0(out)

        out = self.bl1(out)
        out = self.bl1(out)
        out = self.bl1(out)

        out = self.bl2_1(out)
        out = self.bl2_2(out)
        out = self.bl2_2(out)

        out = self.bl3_1(out)
        out = self.bl3_2(out)
        out = self.bl3_2(out)

        out = self.bl4_1(out)
        out = self.bl4_2(out)
        out = self.bl4_2(out)

        out = self.bl5_1(out)
        out = self.bl5_2(out)
        out = self.bl5_2(out)

        out = torch.mean(out, (2, 3))

        out = self.dense(out)

        return out

#loss is BCEWithLogitsLoss

The rest of my hyperparameters are the same in my TensorFlow and PyTorch code. As a sanity check, I confirmed that the number of trainable parameters matches across both implementations. In addition, to check the correctness of my data loader and training loop, I trained a ResNet50 from torchvision, and its validation loss decreased as expected.

I’m stuck on what’s going wrong. Any help is appreciated!

You could compare the default parameter initialization in both frameworks and make sure the ported PyTorch model uses the same scheme as your TF model.
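As a rough sketch of what matching the defaults could look like: this assumes the TF model was built from Keras layers with their default settings (Glorot-uniform kernels, zero biases, batch norm with momentum 0.99 and epsilon 1e-3), which may not be true for your code. The helper name `apply_keras_style_defaults` and the small `demo` model are made up for illustration.

```python
import torch
import torch.nn as nn


def apply_keras_style_defaults(model: nn.Module) -> nn.Module:
    """Re-initialize layers to mimic Keras layer defaults.

    Assumption: the TF model uses Keras Conv2D/Dense/BatchNormalization
    with their default arguments. Adjust if your TF code sets them explicitly.
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Keras default kernel initializer is Glorot (Xavier) uniform;
            # PyTorch defaults to a Kaiming-uniform variant instead.
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                # Keras default bias initializer is zeros.
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.BatchNorm2d):
            # Keras default epsilon is 1e-3 (PyTorch: 1e-5).
            module.eps = 1e-3
            # Keras momentum 0.99 weights the OLD running statistic, while
            # PyTorch momentum weights the NEW batch statistic: 1 - 0.99 = 0.01.
            module.momentum = 0.01
    return model


# Hypothetical stand-in for the first layers of the ported model.
demo = nn.Sequential(nn.Conv2d(3, 32, kernel_size=7), nn.BatchNorm2d(32), nn.ReLU())
apply_keras_style_defaults(demo)
```

In your case the call would be something like `apply_keras_style_defaults(ModifiedRes())`, done once right after constructing the model and before creating the optimizer.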
If the model is still diverging, you could verify the architecture by loading the TF parameters into the PyTorch model (or vice versa) and comparing the outputs for a fixed input tensor.
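Here is a minimal sketch of that output check for a single conv layer. The random NumPy arrays are stand-ins for weights you would export from the TF model, and the NumPy reference convolution is a stand-in for the TF model's output; both are assumptions for illustration only. It also assumes the TF side uses the usual NHWC layout with kernels shaped (kh, kw, in, out) and "valid" padding.

```python
import numpy as np
import torch
import torch.nn as nn

rng = np.random.default_rng(0)

# Stand-ins for parameters exported from the TF layer (hypothetical values).
tf_kernel = rng.standard_normal((7, 7, 3, 32)).astype(np.float32)  # (kh, kw, in, out)
tf_bias = rng.standard_normal(32).astype(np.float32)

# Load them into the PyTorch layer, which expects (out, in, kh, kw).
conv = nn.Conv2d(3, 32, kernel_size=7)
with torch.no_grad():
    conv.weight.copy_(torch.from_numpy(tf_kernel.transpose(3, 2, 0, 1).copy()))
    conv.bias.copy_(torch.from_numpy(tf_bias))

# One fixed input, fed to both sides: NHWC for TF, NCHW for PyTorch.
x_nhwc = rng.standard_normal((2, 16, 16, 3)).astype(np.float32)
x_nchw = torch.from_numpy(x_nhwc.transpose(0, 3, 1, 2).copy())

with torch.no_grad():
    # Convert the PyTorch output back to NHWC so the two are comparable.
    out_pt = conv(x_nchw).permute(0, 2, 3, 1).numpy()

# Reference output. With a real TF model this would be its forward pass on
# x_nhwc; here a plain NumPy "valid" convolution plays that role.
windows = np.lib.stride_tricks.sliding_window_view(x_nhwc, (7, 7), axis=(1, 2))
out_ref = np.einsum("nhwcij,ijco->nhwo", windows, tf_kernel) + tf_bias

max_abs_diff = float(np.abs(out_pt - out_ref).max())
print("max abs difference:", max_abs_diff)
```

For the full model you would do the same per layer, call `model.eval()` on the PyTorch side (and run the TF model in inference mode) so batch norm uses its running statistics, and compare intermediate activations block by block to find the first place where the outputs differ.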