DeepFont in PyTorch not converging

Hi!

I’m implementing the original DeepFont paper in PyTorch (several modified versions exist in Keras). Unfortunately, despite my attempts, I’m unable to make the model converge. I’ve tried to debug the problem by deliberately overfitting the model on a small dataset, but this didn’t work either: the training loss (MSE) starts at around 0.0016 and keeps hovering around that value, regardless of changes to the learning rate, the number of epochs or the number of samples.
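
A common way to run the overfitting check described above is to repeat the optimizer step on one fixed mini-batch and watch the loss: if it does not fall well below its starting value, the problem usually lies in the model, the loss or the targets rather than in the amount of data. The sketch below is a hypothetical helper (`overfit_one_batch` is not from the post), and it uses Adam as a common default for this kind of check instead of the SGD setup used later in the thread.

```python
import torch
import torch.nn as nn


def overfit_one_batch(model, x, y, criterion, steps=200, lr=1e-2):
    """Train repeatedly on a single fixed batch and return the loss history.

    Hypothetical debugging helper: a model that can learn at all should
    drive this loss far below its first value.
    """
    model.train()
    # Only optimize parameters that are not frozen.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=lr)
    history = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        history.append(loss.item())
    return history
```

For this model, `x` would be a small batch of patches shaped `(N, 1, 105, 105)` and, with `nn.MSELoss`, `y` a one-hot tensor shaped `(N, num_classes)`.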

This is how I’m defining the autoencoder that takes the image patches as input:

```python
import torch.nn as nn


class DeepFontAutoencoder(nn.Module):
    def __init__(self, num_classes):
        super(DeepFontAutoencoder, self).__init__()

        # Encoder (spatial sizes in the comments assume a 1x105x105 input patch)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=12, stride=2, padding=1),  # -> 64x48x48
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),  # -> 64x24x24
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),  # -> 128x24x24
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.MaxPool2d(2),  # -> 128x12x12
        )

        # Decoder: mirrors the encoder back to 1x105x105
        self.decoder = nn.Sequential(
            # See https://stackoverflow.com/a/58207809/261698
            nn.Upsample(scale_factor=2, mode='bilinear'),  # -> 128x24x24
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),  # -> 64x24x24
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear'),  # -> 64x48x48
            nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=12, stride=2, padding=1, output_padding=1),  # -> 1x105x105
            nn.Sigmoid(),  # outputs in (0, 1)
        )

    def forward(self, x):
        debug_shapes = False
        # ConvTranspose2d ~= conv2d + upsample:
        # https://discuss.pytorch.org/t/upsample-conv2d-vs-convtranspose2d/138081
        encoded = self.encoder(x)
        if debug_shapes:
            print(f"Encoder output shape: {encoded.shape}")
        decoded = self.decoder(encoded)
        if debug_shapes:
            print(f"Decoder output shape: {decoded.shape}")

        return decoded
```
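
The spatial sizes above can be checked by hand with the standard output-size formulas for `Conv2d`/`MaxPool2d` and `ConvTranspose2d` (dilation 1). The sketch below is plain Python; the function names are hypothetical and it assumes a 105x105 input patch, as in the Keras model later in the thread.

```python
def conv_out(n, kernel, stride=1, padding=0):
    # Output size of Conv2d / MaxPool2d with dilation 1 (floor mode).
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose_out(n, kernel, stride=1, padding=0, output_padding=0):
    # Output size of ConvTranspose2d with dilation 1.
    return (n - 1) * stride - 2 * padding + kernel + output_padding


def trace_autoencoder(n=105):
    """Spatial size after each layer of DeepFontAutoencoder for an n x n patch."""
    trace = []
    n = conv_out(n, 12, stride=2, padding=1)
    trace.append(("conv 12x12 /2", n))
    n = conv_out(n, 2, stride=2)
    trace.append(("maxpool 2", n))
    n = conv_out(n, 3, stride=1, padding=1)
    trace.append(("conv 3x3", n))
    n = conv_out(n, 2, stride=2)
    trace.append(("maxpool 2", n))  # encoder output
    n = n * 2
    trace.append(("upsample x2", n))
    n = conv_transpose_out(n, 3, stride=1, padding=1)
    trace.append(("convT 3x3", n))
    n = n * 2
    trace.append(("upsample x2", n))
    n = conv_transpose_out(n, 12, stride=2, padding=1, output_padding=1)
    trace.append(("convT 12x12 /2", n))  # decoder output
    return trace


for name, size in trace_autoencoder():
    print(f"{name:>15}: {size}x{size}")
```

The encoder ends at 12x12, which is consistent with the `256 * 12 * 12` input size of `fc1` in the classifier below, and the decoder returns to 105x105.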

And this is how I’m using the autoencoder (the intent is to use the encoded patches to perform the classification):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepFontEncoded(nn.Module):
    def __init__(self, trained_autoencoder, num_classes):
        super(DeepFontEncoded, self).__init__()

        # Earlier attempt, kept for reference: wrap the whole autoencoder.
        # self.autoencoder = trained_autoencoder
        # for param in self.autoencoder.parameters():
        #     param.requires_grad = False

        # Cross-domain subnetwork layers (Cu), taken from the trained encoder.
        # Freeze them so the autoencoder weights are not trained again.
        self.ae_encoder = trained_autoencoder.encoder
        for param in self.ae_encoder.parameters():
            param.requires_grad = False

        # Domain-specific layers (Cs); padding=1 keeps the 12x12 spatial size
        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        self.flatten = nn.Flatten()

        fcn_size = 4096
        self.fc1 = nn.Linear(256 * 12 * 12, fcn_size)
        self.drop1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(fcn_size, fcn_size)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(fcn_size, num_classes)

    def forward(self, x):
        # Cu layers (frozen encoder)
        with torch.no_grad():
            #x = self.autoencoder.encoder(x)
            x = self.ae_encoder(x)
        #print(f"Output Cu: {x.shape}")

        # Cs layers
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        #print(f"Output Cs: {x.shape}")

        x = self.flatten(x)

        x = F.relu(self.fc1(x))
        x = self.drop1(x)
        x = F.relu(self.fc2(x))
        x = self.drop2(x)
        # Class probabilities (softmax), as in the paper
        x = F.softmax(self.fc3(x), dim=1)

        return x
```
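
One detail worth checking here (a possible factor, not a confirmed diagnosis): the frozen encoder contains `BatchNorm2d` layers. Setting `requires_grad = False` and wrapping the call in `torch.no_grad()` stop gradient updates, but while the model is in training mode (after `model.train()`), BatchNorm still normalizes with per-batch statistics and keeps updating its `running_mean`/`running_var` buffers. A minimal sketch that keeps a frozen submodule in eval mode by overriding `train()`; `FrozenEncoder` is a hypothetical name.

```python
import torch
import torch.nn as nn


class FrozenEncoder(nn.Module):
    """Wraps a pretrained encoder so it is never trained and BatchNorm
    running statistics are never updated (hypothetical helper)."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.encoder.eval()

    def train(self, mode=True):
        # Let the parent toggle its flags, then force the encoder back to eval
        # so BatchNorm uses (and does not overwrite) its running statistics.
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            return self.encoder(x)
```

Inside `DeepFontEncoded.__init__` this would replace the manual freezing, e.g. `self.ae_encoder = FrozenEncoder(trained_autoencoder.encoder)`, and `model.train()` on the classifier would then leave the encoder in eval mode.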

I’m able to easily train, overfit and test the autoencoder on its own, and it produces a reasonable approximation of the patches I feed to it. However, when I train the classifier with the following setup:

```python
model = DeepFontEncoded(trained_autoencoder=autoenc_model, num_classes=full_data_num_classes)
model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=0.0005, nesterov=True)
criterion = nn.MSELoss()
```

the model does not converge. How can I understand what’s happening, or debug this further?

Thanks!

Using nn.MSELoss on softmax outputs is quite uncommon. Is this how you were able to make the model converge in Keras, or did you use another criterion? I’m not familiar with this model, so your use case might be valid.
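
To make the two options concrete: `nn.MSELoss` compares tensors element by element, so with softmax outputs of shape `(N, C)` the targets must be one-hot (or otherwise dense) tensors of the same shape, not integer class indices. The more common setup is `nn.CrossEntropyLoss`, which applies log-softmax internally and therefore expects raw logits (no `F.softmax` at the end of `forward`) together with integer labels. Both helper names below are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mse_on_softmax(logits, labels):
    # MSE needs a dense target with the same shape as the output,
    # so integer class labels are converted to one-hot vectors first.
    probs = F.softmax(logits, dim=1)
    target = F.one_hot(labels, num_classes=logits.shape[1]).float()
    return F.mse_loss(probs, target)


def cross_entropy_on_logits(logits, labels):
    # CrossEntropyLoss applies log-softmax itself: pass raw logits
    # and integer class indices of shape (N,).
    return nn.CrossEntropyLoss()(logits, labels)


logits = torch.zeros(1, 2)   # a network that cannot tell the 2 classes apart
labels = torch.tensor([0])
print(mse_on_softmax(logits, labels).item())
print(cross_entropy_on_logits(logits, labels).item())
```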

Thank you very much for your quick reply! I understand this looks a bit odd (MSE + softmax); I’m trying to replicate exactly the results from (my understanding of) the paper before moving on to my own improvements. The authors mention using softmax at the end of the network and MSE at least for the autoencoder part (they don’t explicitly say which loss they use for the whole network, so I’m assuming it’s the same).
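
With MSE on softmax outputs, one quick sanity check is to compare the plateau (about 0.0016 in the first post) with the loss of a network that ignores its input and predicts a uniform distribution over C classes. Assuming one-hot targets and `nn.MSELoss`'s default mean reduction over all N x C elements, that loss is ((1 - 1/C)^2 + (C - 1)/C^2) / C = (C - 1)/C^2, roughly 1/C. If the plateau is close to this value for the dataset's class count, the classifier is probably outputting near-uniform probabilities. `uniform_softmax_mse` is a hypothetical helper.

```python
def uniform_softmax_mse(num_classes):
    """MSE between a uniform softmax output and a one-hot target,
    averaged over all C elements (like nn.MSELoss's 'mean' reduction)."""
    p = 1.0 / num_classes
    # Squared error on the true class plus the (C - 1) wrong classes.
    per_sample_sum = (1.0 - p) ** 2 + (num_classes - 1) * p ** 2
    return per_sample_sum / num_classes  # equals (C - 1) / C**2


for c in (2, 10, 100, 1000):
    print(c, uniform_softmax_mse(c))
```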

But yes, this is how I made it converge in Keras. I’m able to reach 93% validation accuracy fairly easily with this Keras network (even without using the pretrained autoencoder weights!):

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense, Dropout,
                                     Flatten, MaxPooling2D, ZeroPadding2D)


def create_model(num_classes):
  model = Sequential()

  # Cu Layers
  # ZeroPadding2D(1) + 'valid' conv is equivalent to PyTorch's Conv2d(padding=1)
  model.add(ZeroPadding2D(padding=(1, 1)))
  model.add(Conv2D(64, kernel_size=(12, 12), strides=2, padding='valid', activation='relu', input_shape=(105,105,1)))
  model.add(BatchNormalization())
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu'))
  model.add(BatchNormalization())
  model.add(MaxPooling2D(pool_size=(2, 2)))

  # Cs Layers
  model.add(Conv2D(256, kernel_size=(3, 3), padding='same', activation='relu'))
  model.add(Conv2D(256, kernel_size=(3, 3), padding='same', activation='relu'))
  model.add(Conv2D(256, kernel_size=(3, 3), padding='same', activation='relu'))

  model.add(Flatten())
  model.add(Dense(4096, activation='relu'))
  model.add(Dropout(0.5))
  model.add(Dense(4096, activation='relu'))
  model.add(Dropout(0.5))

  model.add(Dense(num_classes, activation='softmax'))

  return model
```
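
Since this Keras network trains even from scratch, a useful control experiment is to port it layer for layer to PyTorch with no pretrained encoder at all, so that only the framework differs. A sketch, assuming 1x105x105 inputs; `make_deepfont_classifier` and `fc_size` are hypothetical names (`fc_size` defaults to the 4096 used above and only exists so the model can be shrunk for quick tests). The BatchNorm arguments are set to Keras's defaults (Keras `momentum=0.99` corresponds to PyTorch `momentum=0.01`, and Keras `epsilon=1e-3`). Weight initialization still differs: Keras `Conv2D`/`Dense` default to Glorot uniform, while PyTorch layers use their own default initializers.

```python
import torch
import torch.nn as nn


def make_deepfont_classifier(num_classes, fc_size=4096):
    """PyTorch translation of the Keras create_model (hypothetical helper)."""
    return nn.Sequential(
        # Cu layers
        nn.Conv2d(1, 64, kernel_size=12, stride=2, padding=1), nn.ReLU(),
        nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),
        nn.MaxPool2d(2),
        nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(),
        nn.BatchNorm2d(128, eps=1e-3, momentum=0.01),
        nn.MaxPool2d(2),
        # Cs layers
        nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.ReLU(),
        nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(),
        nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(),
        # Classifier head
        nn.Flatten(),
        nn.Linear(256 * 12 * 12, fc_size), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(fc_size, fc_size), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(fc_size, num_classes),
        nn.Softmax(dim=1),  # mirrors Dense(..., activation='softmax')
    )
```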

And this is how I train:

```python
from tensorflow.keras import optimizers

batch_size = 128
epochs = 10
model = create_model(num_classes=num_classes)
model.build(input_shape=(None, 105, 105, 1))
model.summary()

# `decay` here is Keras's legacy per-iteration learning-rate decay, not L2 weight decay
optimizer = optimizers.SGD(learning_rate=0.01, decay=0.0005, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])
```
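
Two differences between the setups are visible in the thread itself. The Keras run uses `learning_rate=0.01`, while the PyTorch run uses `lr=0.0001` (100x smaller). And in the Keras versions whose `SGD` accepts `decay`, that argument shrinks the learning rate per iteration as lr / (1 + decay * iterations); it is not the L2 penalty that PyTorch's `weight_decay=0.0005` applies. A sketch of the Keras schedule in PyTorch using `LambdaLR`; `keras_style_sgd` is a hypothetical name, and the two frameworks' Nesterov momentum updates are not guaranteed to match exactly.

```python
import torch


def keras_style_sgd(params, lr=0.01, decay=0.0005, momentum=0.9):
    """SGD plus a LambdaLR schedule mimicking legacy Keras `decay`
    (hypothetical helper). Call scheduler.step() once per batch."""
    optimizer = torch.optim.SGD(params, lr=lr, momentum=momentum, nesterov=True)
    # lr_t = lr / (1 + decay * t), where t counts optimizer steps (batches).
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: 1.0 / (1.0 + decay * step)
    )
    return optimizer, scheduler
```

In a training loop, `optimizer.step()` would be followed by `scheduler.step()` after every batch, not every epoch, to match the per-iteration Keras behaviour.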