Model conversion (Lua Torch to PyTorch): validation results do not match

I have this model in Lua Torch that I wish to reproduce in PyTorch.

nn.Sequential {
  [input -> (1) -> (2) -> output]
  (1): nn.ParallelTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> output]
      |      (1): nn.ParallelTable {
      |        input
      |          |`-> (1): nn.Sequential {
      |          |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> (24) -> output]
      |          |      (1): nn.SpatialConvolution(3 -> 32, 7x7)
      |          |      (2): nn.ReLU
      |          |      (3): nn.SpatialBatchNormalization (4D) (32)
      |          |      (4): nn.Dropout(0.200000)
      |          |      (5): nn.SpatialConvolution(32 -> 32, 7x7)
      |          |      (6): nn.ReLU
      |          |      (7): nn.SpatialBatchNormalization (4D) (32)
      |          |      (8): nn.Dropout(0.500000)
      |          |      (9): nn.SpatialConvolution(32 -> 64, 7x7)
      |          |      (10): nn.ReLU
      |          |      (11): nn.SpatialBatchNormalization (4D) (64)
      |          |      (12): nn.Dropout(0.500000)
      |          |      (13): nn.SpatialConvolution(64 -> 64, 7x7)
      |          |      (14): nn.ReLU
      |          |      (15): nn.SpatialBatchNormalization (4D) (64)
      |          |      (16): nn.Dropout(0.500000)
      |          |      (17): nn.SpatialConvolution(64 -> 64, 7x7)
      |          |      (18): nn.ReLU
      |          |      (19): nn.SpatialBatchNormalization (4D) (64)
      |          |      (20): nn.Dropout(0.500000)
      |          |      (21): nn.SpatialConvolution(64 -> 64, 7x7)
      |          |      (22): nn.SpatialBatchNormalization (4D) (64)
      |          |      (23): nn.Transpose
      |          |      (24): nn.Reshape(1x64)
      |          |    }
      |           `-> (2): nn.Sequential {
      |                 [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> output]
      |                 (1): nn.SpatialConvolution(3 -> 32, 7x7)
      |                 (2): nn.ReLU
      |                 (3): nn.SpatialBatchNormalization (4D) (32)
      |                 (4): nn.Dropout(0.200000)
      |                 (5): nn.SpatialConvolution(32 -> 32, 7x7)
      |                 (6): nn.ReLU
      |                 (7): nn.SpatialBatchNormalization (4D) (32)
      |                 (8): nn.Dropout(0.500000)
      |                 (9): nn.SpatialConvolution(32 -> 64, 7x7)
      |                 (10): nn.ReLU
      |                 (11): nn.SpatialBatchNormalization (4D) (64)
      |                 (12): nn.Dropout(0.500000)
      |                 (13): nn.SpatialConvolution(64 -> 64, 7x7)
      |                 (14): nn.ReLU
      |                 (15): nn.SpatialBatchNormalization (4D) (64)
      |                 (16): nn.Dropout(0.500000)
      |                 (17): nn.SpatialConvolution(64 -> 64, 7x7)
      |                 (18): nn.ReLU
      |                 (19): nn.SpatialBatchNormalization (4D) (64)
      |                 (20): nn.Dropout(0.500000)
      |                 (21): nn.SpatialConvolution(64 -> 64, 7x7)
      |                 (22): nn.SpatialBatchNormalization (4D) (64)
      |                 (23): nn.Reshape(64x121)
      |               }
      |           ... -> output
      |      }
      |      (2): nn.MM
      |      (3): nn.Reshape(121)
      |      (4): nn.LogSoftMax
      |    }
       `-> (2): nn.Sequential {
             [input -> (1) -> (2) -> (3) -> (4) -> output]
             (1): nn.ParallelTable {
               input
                 |`-> (1): nn.Sequential {
                 |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> (24) -> output]
                 |      (1): nn.SpatialConvolution(3 -> 32, 7x7)
                 |      (2): nn.ReLU
                 |      (3): nn.SpatialBatchNormalization (4D) (32)
                 |      (4): nn.Dropout(0.200000)
                 |      (5): nn.SpatialConvolution(32 -> 32, 7x7)
                 |      (6): nn.ReLU
                 |      (7): nn.SpatialBatchNormalization (4D) (32)
                 |      (8): nn.Dropout(0.500000)
                 |      (9): nn.SpatialConvolution(32 -> 64, 7x7)
                 |      (10): nn.ReLU
                 |      (11): nn.SpatialBatchNormalization (4D) (64)
                 |      (12): nn.Dropout(0.500000)
                 |      (13): nn.SpatialConvolution(64 -> 64, 7x7)
                 |      (14): nn.ReLU
                 |      (15): nn.SpatialBatchNormalization (4D) (64)
                 |      (16): nn.Dropout(0.500000)
                 |      (17): nn.SpatialConvolution(64 -> 64, 7x7)
                 |      (18): nn.ReLU
                 |      (19): nn.SpatialBatchNormalization (4D) (64)
                 |      (20): nn.Dropout(0.500000)
                 |      (21): nn.SpatialConvolution(64 -> 64, 7x7)
                 |      (22): nn.SpatialBatchNormalization (4D) (64)
                 |      (23): nn.Transpose
                 |      (24): nn.Reshape(1x64)
                 |    }
                  `-> (2): nn.Sequential {
                        [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> (8) -> (9) -> (10) -> (11) -> (12) -> (13) -> (14) -> (15) -> (16) -> (17) -> (18) -> (19) -> (20) -> (21) -> (22) -> (23) -> output]
                        (1): nn.SpatialConvolution(3 -> 32, 7x7)
                        (2): nn.ReLU
                        (3): nn.SpatialBatchNormalization (4D) (32)
                        (4): nn.Dropout(0.200000)
                        (5): nn.SpatialConvolution(32 -> 32, 7x7)
                        (6): nn.ReLU
                        (7): nn.SpatialBatchNormalization (4D) (32)
                        (8): nn.Dropout(0.500000)
                        (9): nn.SpatialConvolution(32 -> 64, 7x7)
                        (10): nn.ReLU
                        (11): nn.SpatialBatchNormalization (4D) (64)
                        (12): nn.Dropout(0.500000)
                        (13): nn.SpatialConvolution(64 -> 64, 7x7)
                        (14): nn.ReLU
                        (15): nn.SpatialBatchNormalization (4D) (64)
                        (16): nn.Dropout(0.500000)
                        (17): nn.SpatialConvolution(64 -> 64, 7x7)
                        (18): nn.ReLU
                        (19): nn.SpatialBatchNormalization (4D) (64)
                        (20): nn.Dropout(0.500000)
                        (21): nn.SpatialConvolution(64 -> 64, 7x7)
                        (22): nn.SpatialBatchNormalization (4D) (64)
                        (23): nn.Reshape(64x121)
                      }
                  ... -> output
             }
             (2): nn.MM
             (3): nn.Reshape(121)
             (4): nn.LogSoftMax
           }
       ... -> output
  }
  (2): nn.CAddTable
}

In PyTorch (version 1.0), I defined the same model as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SiameseNetwork(nn.Module):
    def __init__(self, nb_channels):
        super(SiameseNetwork, self).__init__()
        self.subnet = SubNetwork(nb_channels=nb_channels)

    def forward(self, ll, lr, rl, rr):
        # Lua: nn.ParallelTable applying the sub-network to each pair, then nn.CAddTable
        left = self.subnet(ll, lr)
        right = self.subnet(rl, rr)
        return left + right


class SubNetwork(nn.Module):
    def __init__(self, nb_channels):
        super(SubNetwork, self).__init__()
        self.features = FeaturesExtractor(nb_channels=nb_channels, kernel_size=(7, 7))

    def forward(self, left, right):
        # Shared feature extractor (the two inner nn.Sequential branches in the Lua model)
        left = self.features(left)
        right = self.features(right)
        # Lua: nn.Transpose + nn.Reshape(1x64) on the left branch
        left = torch.transpose(left, dim0=1, dim1=2)
        left = torch.reshape(left, shape=(-1, 1, 64))
        # Lua: nn.Reshape(64x121) on the right branch
        right = torch.reshape(right, shape=(-1, 64, 121))
        # Lua: nn.MM, then nn.Reshape(121) and nn.LogSoftMax
        correlation = torch.matmul(left, right)
        correlation = torch.reshape(correlation, shape=(-1, 121))
        probabilities = F.log_softmax(correlation, dim=1)
        return probabilities


class FeaturesExtractor(nn.Module):
    def __init__(self, nb_channels=3, kernel_size=(7, 7)):
        super(FeaturesExtractor, self).__init__()
        self.nb_channels = nb_channels
        self.ksize = kernel_size
        self.dropout_in = nn.Dropout(p=0.2)
        self.dropout = nn.Dropout(p=0.5)

        self.convbnrelu1 = ConvBNReLU(in_dim=self.nb_channels, out_dim=32, ksize=self.ksize)
        self.convbnrelu2 = ConvBNReLU(in_dim=32, out_dim=32, ksize=self.ksize)
        self.convbnrelu3 = ConvBNReLU(in_dim=32, out_dim=64, ksize=self.ksize)
        self.convbnrelu4 = ConvBNReLU(in_dim=64, out_dim=64, ksize=self.ksize)
        self.convbnrelu5 = ConvBNReLU(in_dim=64, out_dim=64, ksize=self.ksize)
        self.convbn6 = ConvBN(in_dim=64, out_dim=64, ksize=self.ksize)

    def forward(self, x):
        y = self.dropout_in(self.convbnrelu1(x))
        y = self.dropout(self.convbnrelu2(y))
        y = self.dropout(self.convbnrelu3(y))
        y = self.dropout(self.convbnrelu4(y))
        y = self.dropout(self.convbnrelu5(y))
        y = self.convbn6(y)
        return y


class ConvBN(nn.Module):
    def __init__(self, in_dim, out_dim, ksize):
        super(ConvBN, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=ksize)
        self.bn = nn.BatchNorm2d(num_features=out_dim, eps=0.001)

    def forward(self, x):
        return self.bn(self.conv(x))


class ConvBNReLU(nn.Module):
    def __init__(self, in_dim, out_dim, ksize):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_dim, out_channels=out_dim, kernel_size=ksize)
        self.bn = nn.BatchNorm2d(num_features=out_dim, eps=0.001)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))

Both models have the same number of parameters.
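
As a sanity check, the sketch below counts the parameters and runs a dummy forward pass in eval mode. The 37x37 / 37x157 patch sizes are only an assumption, chosen so that with the six 7x7 convolutions the feature maps flatten to 1x64 and 64x121; any sizes satisfying those shapes would do.

model = SiameseNetwork(nb_channels=3)
print(sum(p.numel() for p in model.parameters()))  # total parameter count

model.eval()
with torch.no_grad():
    # Assumed patch sizes: left 37x37 -> 1x1 feature map, right 37x157 -> 1x121
    ll, rl = torch.randn(2, 3, 37, 37), torch.randn(2, 3, 37, 37)
    lr, rr = torch.randn(2, 3, 37, 157), torch.randn(2, 3, 37, 157)
    out = model(ll, lr, rl, rr)
print(out.shape)  # torch.Size([2, 121])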

The problem is that both models reach very similar training loss and accuracy, but the PyTorch model does not reach the same validation results as the Lua model. The numbers below are given as (accuracy - loss).
Lua model:

  • Epoch 1: train (19.6% - 8.51), validation (36.5% - 8.57)
  • Epoch 2: train (39.6% - 7.44), validation (44.5% - 7.79)
  • Epoch 100: train (91.8% - 4.95), validation (87.0% - 5.66)

PyTorch model:

  • Epoch 1: train (21.7% - 11.35), validation (14.49% - 12.50)
  • Epoch 2: train (38.6% - 8.07), validation (14.1% - 11.95)
  • Epoch 100: train (92.9% - 4.78), validation (67.5% - 8.29)

Things I tried:

  • Checked that I had the same inputs, both for training and validation (yes).
  • Made sure both models are in eval mode during validation (see the sketch after this list).
  • Used the training data during validation (for the PyTorch model); it did not improve the validation results.
  • Changed from PyTorch 1.0 to 1.3 (no difference).
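
For reference, my validation pass looks roughly like the minimal sketch below (eval mode, no gradient tracking). val_loader and device are hypothetical placeholders, and the NLL loss is an assumption based on the model returning log-probabilities, not the exact code.

# Minimal sketch of the validation pass; val_loader and device are
# hypothetical placeholders, and nll_loss is assumed since the model
# returns log-probabilities from log_softmax.
model.eval()
total_loss, correct, total = 0.0, 0, 0
with torch.no_grad():
    for (ll, lr, rl, rr), target in val_loader:
        output = model(ll.to(device), lr.to(device), rl.to(device), rr.to(device))
        total_loss += F.nll_loss(output, target.to(device), reduction='sum').item()
        correct += (output.argmax(dim=1) == target.to(device)).sum().item()
        total += target.size(0)
print(correct / total, total_loss / total)  # validation accuracy and mean loss
model.train()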

My first thought was that my PyTorch model was overfitting, but that does not explain why the Lua model does not have the same problem (with the same data). I am looking for suggestions as to where the problem might be.

Thank you,
D-A