Can't recover Keras results using Pytorch

I have been working on a regression task using CNNs. At the moment I’m testing an AlexNet-like architecture in Keras. After only two epochs, one can already notice some (still weak) correlation between the predictions and the actual target values. With PyTorch, however, I struggle to see any trend in the results even after ten epochs. I must admit I’m quite new to PyTorch; this is the very first time I’m using it. Please see the Keras implementation below:

from keras.layers import (Conv2D, MaxPooling2D, BatchNormalization,
                          Flatten, Dense, Dropout)

# input_img is the Input tensor defined earlier in my script
x = Conv2D(16, (7, 7), activation='relu', padding='same', strides=2)(input_img)
x = MaxPooling2D((3, 3), padding='same', strides=2)(x)
x = BatchNormalization()(x)
x = Conv2D(32, (5, 5), activation='relu', padding='same', strides=2)(x)
x = MaxPooling2D((3, 3), padding='same', strides=2)(x)
x = BatchNormalization()(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same', strides=2)(x)
x = Conv2D(64, (3, 3), activation='relu', padding='same', strides=2)(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same', strides=2)(x)
x = MaxPooling2D((3, 3), padding='same', strides=2)(x)
x = BatchNormalization()(x)
x = Flatten()(x)
x = Dense(512, activation='tanh')(x)
x = Dropout(0.5)(x)
x = Dense(512, activation='tanh')(x)
x = Dropout(0.5)(x)
x = Dense(9, activation=None)(x)

and the PyTorch implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class AlexNetTest(nn.Module):

    def __init__(self):
        super(AlexNetTest, self).__init__()

        self.C1 = nn.Conv2d(1, 8, 7, stride=2, padding=3)
        torch.nn.init.xavier_uniform_(self.C1.weight)
        torch.nn.init.constant_(self.C1.bias, 0)
        self.S2 = nn.MaxPool2d(3, stride=2, padding=1)  # first pooling
        self.BC1S2 = nn.BatchNorm2d(8, momentum=0.99, eps=1.0e-3)

        self.C3 = nn.Conv2d(8, 16, 5, stride=2, padding=2)
        torch.nn.init.xavier_uniform_(self.C3.weight)
        torch.nn.init.constant_(self.C3.bias, 0)
        self.S4 = nn.MaxPool2d(3, stride=2, padding=1)  # second pooling
        self.BC3S4 = nn.BatchNorm2d(16, momentum=0.99, eps=1.0e-3)

        self.C5 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        torch.nn.init.xavier_uniform_(self.C5.weight)
        torch.nn.init.constant_(self.C5.bias, 0)
        self.C6 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
        torch.nn.init.xavier_uniform_(self.C6.weight)
        torch.nn.init.constant_(self.C6.bias, 0)
        self.C7 = nn.Conv2d(32, 16, 3, stride=2, padding=1)
        torch.nn.init.xavier_uniform_(self.C7.weight)
        torch.nn.init.constant_(self.C7.bias, 0)
        self.S8 = nn.MaxPool2d(3, stride=2, padding=1)  # third pooling
        self.BC7S8 = nn.BatchNorm2d(16, momentum=0.99, eps=1.0e-3)

        self.F8 = nn.Linear(16 * 1 * 1, 512)
        torch.nn.init.xavier_uniform_(self.F8.weight)
        torch.nn.init.constant_(self.F8.bias, 0)
        self.F9 = nn.Linear(512, 512)
        torch.nn.init.xavier_uniform_(self.F9.weight)
        torch.nn.init.constant_(self.F9.bias, 0)
        self.Out = nn.Linear(512, 9)
        torch.nn.init.xavier_uniform_(self.Out.weight)
        torch.nn.init.constant_(self.Out.bias, 0)

    def forward(self, x):
        x = F.relu(self.BC1S2(self.C1(x)))
        x = self.S2(x)

        x = F.relu(self.BC3S4(self.C3(x)))
        x = self.S4(x)

        x = F.relu(self.C5(x))
        x = F.relu(self.C6(x))
        x = F.relu(self.BC7S8(self.C7(x)))
        x = self.S8(x)

        x = x.view(-1, 16 * 1 * 1)

        x = F.dropout(torch.tanh(self.F8(x)), p=0.5, training=True)
        x = F.dropout(torch.tanh(self.F9(x)), p=0.5, training=True)
        x = self.Out(x)
        return x

In both cases I use the same optimizer: RMSprop(lr = 0.001) in Keras and torch.optim.RMSprop(net.parameters(), lr = 0.001, alpha = 0.9) in PyTorch. I wonder if you can help me spot the issue with my PyTorch implementation.
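For reference, this is roughly how I set up the two optimizers (the loss shown is just a placeholder for the one I actually use, and model/net are the two models above):

# Keras: rho defaults to 0.9, matching alpha below
from keras.optimizers import RMSprop
model.compile(optimizer=RMSprop(lr=0.001), loss='mse')

# PyTorch
import torch.optim as optim
optimizer = optim.RMSprop(net.parameters(), lr=0.001, alpha=0.9)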

Thanks a lot in advance.

There might be some differences and issues in your PyTorch code:

  • you seem to be using half the number of filters in each conv layer compared to the Keras model
  • your Keras implementation applies conv - relu - pool - bn, while in PyTorch you are using conv - bn - relu - pool
  • dropout is always active in your PyTorch model, because you pass training=True unconditionally; use the nn.Dropout module or pass training=self.training (see the sketch below)
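As a rough sketch, using the module names from your code, the first block and the classifier could look like this (the remaining blocks follow the same pattern):

def forward(self, x):
    x = F.relu(self.C1(x))  # conv - relu
    x = self.S2(x)          # pool
    x = self.BC1S2(x)       # bn, matching the Keras order
    # ... same reordering for the remaining conv/pool/bn blocks ...
    x = x.view(x.size(0), -1)
    # tie dropout to the module's mode, so that net.eval() disables it
    x = F.dropout(torch.tanh(self.F8(x)), p=0.5, training=self.training)
    x = F.dropout(torch.tanh(self.F9(x)), p=0.5, training=self.training)
    x = self.Out(x)
    return x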

@ptrblck Thanks for your reply. I changed the number of filters and the sequence from conv - bn - relu - pool to conv - relu - pool - bn to match the Keras code. As for the dropout, I now use nn.Dropout and explicitly switch to training mode by setting net = net.train() before looping through the batches, and net = net.eval() when validating at the end of each epoch, like the following:

x = nn.Dropout(p = 0.5)(torch.tanh(self.F8(x)))
......
net = net.train()
for local_batch, local_labels in training_generator:
......
net = net.eval()
with torch.no_grad():
.....
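Spelled out a bit more, my loop looks roughly like this (criterion, optimizer and the validation batch names are placeholders for my actual objects):

net = net.train()  # dropout active
for local_batch, local_labels in training_generator:
    optimizer.zero_grad()
    preds = net(local_batch)
    loss = criterion(preds, local_labels)
    loss.backward()
    optimizer.step()

net = net.eval()  # dropout disabled
with torch.no_grad():
    val_preds = net(local_val_batch)
    val_loss = criterion(val_preds, local_val_labels)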

For each epoch in PyTorch I get:

epoch: 0 | train loss: 0.6084
epoch: 1 | train loss: 0.4554

whereas in Keras I get:

epoch: 0 | loss: 0.9154
epoch: 1 | loss: 0.2524

Over two epochs I still don’t recover the Keras results.

I have just tested my code, and I think the issue (amongst the other ones in the code I posted) is the data loader I created in PyTorch:

import numpy as np
from PIL import Image
from torch.utils import data
from torchvision import transforms


class Dataset(data.Dataset):
    '''
    Simple data loader: reads the features and labels of the requested
    split (types_) from an .npz file.
    '''
    def __init__(self, filename, height, width, types_):
        self.data = np.load(filename)
        self.feats, self.labels = self.data[types_ + '_x'], self.data[types_ + '_l']
        self.height, self.width = height, width
        self.transform = transforms.ToTensor()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        img_as_np  = np.asarray(self.feats[index]).reshape((self.height, self.width))
        img_as_img = Image.fromarray(img_as_np)
        img_as_img = img_as_img.convert('L')
        return self.transform(img_as_img), self.labels[index]
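I feed it to the training loop through a DataLoader, roughly like this (the filename and the split prefix here are just placeholders):

training_set = Dataset('data.npz', height=ndim, width=ndim, types_='train')
training_generator = data.DataLoader(training_set, batch_size=32, shuffle=True)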

I think I did this right, but it appears to have some issue, because when I instead load the batches manually, creating the variables like:

rad = np.random.randint(len(X_train), size=32)
local_batch  = Variable(torch.from_numpy((X_train[rad].reshape((32,1,ndim,ndim))).astype(np.float32))) 
local_labels = Variable(torch.from_numpy(y_train[rad].astype(np.float32)))

I get for each epoch:

epoch: 0 | train loss: 0.7168
epoch: 1 | train loss: 0.2158

which I think is now in the right ballpark, and I also see the trend in the results over two epochs.
One last thing I would like to mention: each epoch takes about 3 mins in Keras, whereas in PyTorch it takes 4-5 mins. Is this expected, or is it due to my implementation? I’m running these tests on the CPU for now.

It looks like you’ve added the Dropout module (and its initialization) into your forward method.
This still won’t work, because a module created inside forward is not registered, so net.eval() cannot switch it off; you should register it as a module in your __init__ method instead:

class AlexNetTest(nn.Module):
    def __init__(self):
        super(AlexNetTest, self).__init__()
        ...
        self.drop1 = nn.Dropout(p=0.5)

    def forward(self, x):
        ...
        x = self.drop1(torch.tanh(self.F8(x)))

In your Dataset you are normalizing the data to the range [0, 1], since transforms.ToTensor does that for PIL images, while in your small numpy script you are loading the data from the plain numpy arrays. Are you preprocessing the images somehow in your Keras script? Could you check the range of your data using the Dataset and your manual approach?
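A quick check along these lines should show the difference (dataset and X_train stand for your actual objects):

sample, _ = dataset[0]  # goes through PIL + ToTensor
print(sample.min().item(), sample.max().item())

manual = torch.from_numpy(np.asarray(X_train[0], dtype=np.float32))
print(manual.min().item(), manual.max().item())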

As a small side note: Variables are deprecated since PyTorch 0.4.0, so if you are using a newer version you can just use tensors instead.
This might also explain part of the performance difference, since performance has improved in the newer versions.
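I.e. instead of wrapping the arrays in Variable, you can create the tensors directly:

local_batch  = torch.from_numpy(X_train[rad].reshape((32, 1, ndim, ndim)).astype(np.float32))
local_labels = torch.from_numpy(y_train[rad].astype(np.float32))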

Thanks a lot for your suggestion. I have now initialised the dropout in __init__.
In the Keras script it’s also plain numpy arrays (without any preprocessing), which I just reshaped. In fact, I switched to plain numpy arrays in my PyTorch code precisely to see whether I would get results similar to the Keras script, and it turned out that this sort of did the job.

Surprisingly, the data coming out of my Dataset are all zeros, whereas with the manual approach they lie in [0.0, 0.53]. Something is definitely wrong with my Dataset.

I also changed Variable(torch.from_numpy()) to simply torch.tensor() and got consistent results, although one epoch still takes ~5 mins.

All in all, with all the changes implemented and with the manual approach to loading the batches, I now also see the trend in my results in PyTorch over two epochs (as in Keras), although it takes a bit longer. It could be, though, that it is the way I implement the validation at each epoch that takes most of the time. I’m also still left with the bug in my Dataset, which I still need to find.

I think the transformation to L mode will destroy your data: L mode casts to 8-bit integers, so float values below 1 are truncated to zero, which matches the all-zero output you observed.
If you don’t really need to transform your numpy array to a PIL.Image, you could directly use:

x = torch.from_numpy(img_as_np)

Otherwise, use the F mode for your transformation:

img_as_img = Image.fromarray(img_as_np, 'F')

and remove the transformation to L.
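A minimal sketch of the corrected __getitem__, assuming the features are stored as float arrays, could then be:

    def __getitem__(self, index):
        img_as_np = np.asarray(self.feats[index], dtype=np.float32).reshape((self.height, self.width))
        # skip PIL entirely and just add the channel dimension
        return torch.from_numpy(img_as_np).unsqueeze(0), self.labels[index]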

Indeed I tried the F mode for the transformation and got consistent results. Many thanks.