Why does data augmentation decrease validation accuracy: pytorch/keras comparison

I have a CNN for CIFAR-10, defined as below:

import torch
import torch.nn as nn


class ConvNet(nn.Module):  # hypothetical class name; the post shows only the layers
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3),
            nn.BatchNorm2d(256),
            nn.ReLU())

        self.layer4 = nn.AvgPool2d(8)
        self.layer5 = nn.Linear(256, num_classes)
        self.layer6 = nn.Softmax(dim=1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = out.reshape(out.size(0), -1)
        out = self.layer5(out)
        out = self.layer6(out)
        return out


model = ConvNet(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

I train this network for 20 epochs with the following data augmentation methods:
1- Random crop (32, padding=4)
2- Random horizontal flip
3- Normalization
4- Random affine for horizontal and vertical translation
5- Mixup (alpha=1.0)
6- Cutout (num_holes=1, size=16)

Each time I add one of the augmentations listed after normalization (4, 5, 6), my validation accuracy drops from 60% to 50%. I know this is possible if the model's capacity is low.

However, when I train the same network in Keras for 20 epochs with the same data augmentation methods, I reach over 70% validation accuracy.

What am I missing?

Note: in the Keras implementation, the convolution and dense layers have L2 kernel regularization, while in the PyTorch implementation only the optimizer applies L2 (weight decay). Could that be the reason?
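One concrete difference between the two setups can be checked in a few lines. This sketch assumes the documented definition of Keras's `regularizers.l2(wd)`, which adds `wd * sum(w**2)` to the loss, while PyTorch's `SGD(weight_decay=wd)` adds `wd * w` to the gradient. The tensor and variable names are only for illustration.

```python
import torch

wd = 0.0005
w = torch.randn(64, 3, 3, 3, requires_grad=True)  # stand-in conv kernel

# Keras-style penalty: regularizers.l2(wd) adds wd * sum(w**2) to the loss.
keras_penalty = wd * w.pow(2).sum()
keras_penalty.backward()
keras_grad = w.grad.clone()

# PyTorch SGD(weight_decay=wd) adds wd * w directly to the gradient instead.
torch_decay_grad = wd * w.detach()

# The gradient of the Keras penalty is 2 * wd * w, i.e. twice as strong.
print(torch.allclose(keras_grad, 2 * torch_decay_grad))  # True
```

So `l2(0.0005)` on a Keras kernel roughly corresponds to `weight_decay=0.001` for that tensor in PyTorch, and Keras only applies it to the kernels it is attached to, whereas the optimizer's `weight_decay` applies to every parameter passed to it.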


I also tried removing weight decay from SGD and adding it manually to the conv and dense layers, as described in this Stack Overflow question: in-pytorch-how-to-add-l1-regularizer-to-activations. Still, the validation accuracy is 50% with all of the listed regularizations, so it can't be the weight decay.

I would really appreciate an explanation of why Keras gets a 20% higher val_acc while PyTorch does not.

If you use nn.CrossEntropyLoss you should pass the logits into this criterion rather than the probabilities from nn.Softmax.
Could you remove self.layer6 and try it again?
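A small check of why the softmax layer hurts, assuming 10 classes as in CIFAR-10: `nn.CrossEntropyLoss` treats its input as logits, so even a perfectly confident one-hot probability vector still produces a loss of log(e + 9) - 1, and the model can never push it below that. With raw logits the loss can approach zero.

```python
import math
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
num_classes = 10
target = torch.tensor([3])

# A perfectly confident "prediction" coming out of nn.Softmax.
probs = torch.zeros(1, num_classes)
probs[0, 3] = 1.0

# CrossEntropyLoss applies log-softmax again, so these are treated as logits.
loss_on_probs = criterion(probs, target).item()
expected_floor = math.log(math.e + (num_classes - 1)) - 1.0
print(loss_on_probs, expected_floor)  # both approximately 1.461

# With raw logits the loss can get arbitrarily close to 0.
logits = torch.zeros(1, num_classes)
logits[0, 3] = 20.0
loss_on_logits = criterion(logits, target).item()
```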


Thank you for your answer! I'm trying it now, but honestly I don't understand the reason.

Shouldn't I use softmax for a multi-class distribution, and cross entropy to undo the exponential with the log? Could you please help me understand?

Also, the Keras model has a softmax activation on its dense layer and uses categorical cross entropy as its loss function.

Internally, nn.CrossEntropyLoss applies nn.LogSoftmax to the input and then uses nn.NLLLoss (negative log likelihood loss), so feeding it softmax probabilities applies the softmax twice.
So you can remove the nn.Softmax layer and pass the logits to nn.CrossEntropyLoss, or alternatively use nn.LogSoftmax() as the last layer and nn.NLLLoss as your criterion.
The reason for this design is that computing the log of a separately computed softmax can be numerically unstable, so nn.LogSoftmax is preferred.
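Both claims can be checked directly. This is a minimal sketch with random logits; the logit gap of 200 is arbitrary and chosen only to trigger float32 underflow.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 10)
targets = torch.tensor([1, 0, 9, 3])

# CrossEntropyLoss on logits == NLLLoss on LogSoftmax outputs.
ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
print(torch.allclose(ce, nll))  # True

# Stability: softmax underflows to 0 for the small entry, so its log is -inf,
# while log_softmax computes the same quantity without leaving log space.
big = torch.tensor([[0.0, 200.0]])
naive = torch.log(torch.softmax(big, dim=1))  # first entry is -inf
stable = torch.log_softmax(big, dim=1)        # first entry is -200
```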

Okay, I get it now, thank you. Before this change, the average validation accuracy of 24 different models was 48.4%. After the change, the average accuracy of 8 models is 52.78%. The Keras model's average accuracy over 20 models, on the other hand, is 64.4%. And if I don't use mixup, cutout, or random affine, the PyTorch models reach around 60%.

There must still be something that I am missing.

These are the data augmentation methods; maybe the problem is here:

import numpy as np
import torch


class Cutout(object):
    """
    Randomly mask out one or more patches from an image.
    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor image of size (C, H, W).
        Returns:
            Tensor: Image with n_holes of dimension length x length cut out of it.
        """
        h = img.size(1)
        w = img.size(2)

        mask = np.ones((h, w), np.float32)

        for n in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = int(np.clip(y - self.length / 2, 0, h))
            y2 = int(np.clip(y + self.length / 2, 0, h))
            x1 = int(np.clip(x - self.length / 2, 0, w))
            x2 = int(np.clip(x + self.length / 2, 0, w))

            mask[y1: y2, x1: x2] = 0.

        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img = img * mask

        return img


def mixup_data(x, y, alpha=1.0, is_cuda=True):
    lam = np.random.beta(alpha, alpha) if alpha > 0. else 1.
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).cuda() if is_cuda else torch.randperm(batch_size)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(y_a, y_b, lam):
    return lambda criterion, pred: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

And my prepare_data() method:

import torchvision
import torchvision.transforms as transforms


def prepare_data(batch_size=128, valid_frac=0.1, manual_seed=0):
    n_holes = 1
    length = 16

    mean = [x / 255.0 for x in [125.3, 123.0, 113.9]]
    std = [x / 255.0 for x in [63.0, 62.1, 66.7]]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomAffine(degrees=0,translate=(0.125, 0.125)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    
    train_transform.transforms.append(Cutout(n_holes=n_holes, length=length))
    
    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, transform=train_transform, download=True)
    # validation images come from the training set but should not be augmented
    valid_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, transform=test_transform, download=True)
    test_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, transform=test_transform, download=True)
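The snippet stops before `valid_frac` and `manual_seed` are used. One possible way to use them, sketched with a hypothetical helper `split_train_valid` and random tensors in place of CIFAR-10, is to draw one seeded permutation and index two dataset objects that share the same images: the training copy built with `train_transform` and the validation copy built with the non-augmented `test_transform`, so validation accuracy is measured on clean images.

```python
import torch
from torch.utils.data import Subset, TensorDataset


def split_train_valid(train_dataset, valid_dataset, valid_frac=0.1, manual_seed=0):
    """Hypothetical helper: split one training set into train/valid subsets.

    Both datasets wrap the same images (only the transform differs), so one
    seeded permutation of indices can be shared between them.
    """
    n = len(train_dataset)
    n_valid = int(n * valid_frac)
    generator = torch.Generator().manual_seed(manual_seed)
    perm = torch.randperm(n, generator=generator).tolist()
    valid_idx, train_idx = perm[:n_valid], perm[n_valid:]
    return Subset(train_dataset, train_idx), Subset(valid_dataset, valid_idx)


# Usage with stand-in tensors instead of the real CIFAR-10 datasets:
images = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
train_part, valid_part = split_train_valid(
    TensorDataset(images, labels), TensorDataset(images, labels),
    valid_frac=0.1, manual_seed=0)
```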

And in the original script, in train() I use mixup as below (alpha=1.0):

    def train(self, x_val, y_val):
        x = Variable(x_val, requires_grad=False)
        y = Variable(y_val, requires_grad=False)

        x, y_a, y_b, lam = utils.mixup_data(x, y, self.alpha)

        x = Variable(x, requires_grad=False)
        y_a = Variable(y_a, requires_grad=False)
        y_b = Variable(y_b, requires_grad=False)

        self.optimizer.zero_grad()

        output = self.forward(x)

        loss_loc = lam * self.loss(output, y_a) + (1 - lam) * self.loss(output, y_b)

        loss_loc.backward(retain_graph=True)
      
        self.optimizer.step()

Your code looks generally good!
Could you try to apply the same weight initializations that are used in Keras to compare the models?
Here is a small example.
Also, could you post the Keras code, as there still might be some small differences?

Some minor issues:

  • Variables are deprecated since PyTorch 0.4.0, and you can use tensors directly.
  • It’s generally recommended to call the model directly instead of forward. You could change self.forward(x) to self(x).
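With both points applied, a mixup training step might look like the sketch below. `mixup_train_step` is a hypothetical standalone function rather than the original `train()` method, the tiny linear model and random batch are stand-ins, and the mixing logic mirrors `mixup_data` from earlier in the thread.

```python
import numpy as np
import torch
import torch.nn as nn


def mixup_train_step(model, optimizer, criterion, x, y, alpha=1.0):
    """One mixup training step using plain tensors (no Variable wrappers)."""
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]

    optimizer.zero_grad()
    output = model(mixed_x)  # call the module instead of model.forward(...)
    loss = lam * criterion(output, y_a) + (1 - lam) * criterion(output, y_b)
    loss.backward()  # a single backward pass does not need retain_graph=True
    optimizer.step()
    return loss.item()


# Usage with a tiny stand-in model and random data:
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
x = torch.randn(8, 3, 32, 32)
y = torch.randint(0, 10, (8,))
loss_value = mixup_train_step(model, optimizer, criterion, x, y)
```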

The Keras model has no specific weight initialization. According to the source code, conv and dense layer kernels are initialized with glorot_uniform by default, so I will try to implement glorot_uniform now, thanks! (Is it the same as Xavier?)
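A minimal sketch of matching those Keras defaults in PyTorch, assuming glorot_uniform kernels and zero biases for conv and dense layers; `keras_like_init` is a hypothetical name and the small model is only a stand-in. PyTorch's `nn.init.xavier_uniform_` computes fan-in and fan-out for conv kernels the same way (channels times receptive field size).

```python
import torch.nn as nn


def keras_like_init(module):
    """Mimic Keras defaults: Glorot (Xavier) uniform kernels, zero biases."""
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),
)
model.apply(keras_like_init)  # visits every submodule recursively
```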

Here is a sketch of the Keras model; I'm not allowed to share it as is.

Conv2D(filters = 64,
        kernel_size = 3,
        activation = None,
        padding='same',
        kernel_regularizer = regularizers.l2(weight_decay))
BatchNormalization()
Activation('relu')
MaxPooling2D( pool_size = 2)

Conv2D(filters = 128,
        kernel_size = 3,
        activation = None,
        padding='same',
        kernel_regularizer = regularizers.l2(weight_decay))
BatchNormalization()
Activation('relu')
MaxPooling2D( pool_size = 2)

Conv2D(filters = 256,
        kernel_size = 3,
        activation = None,
        padding='same',
        kernel_regularizer = regularizers.l2(weight_decay))
BatchNormalization()
Activation('relu')

GlobalAveragePooling2D()
Dense( units = 10,
        activation = 'softmax',
        kernel_regularizer = regularizers.l2(weight_decay) 
        )

opt_algo = optimizers.SGD(lr = 0.01, momentum = 0.9)
keras_model.compile(optimizer = opt_algo,
                    loss = 'categorical_crossentropy',
                    metrics = ['accuracy'])


train_datagen_pre = ImageDataGenerator(
                featurewise_center=False,  
                samplewise_center=False, 
                featurewise_std_normalization=False, 
                samplewise_std_normalization=False,  
                zca_whitening=False, 
                rotation_range=0, 
                width_shift_range=0.125,  
                height_shift_range=0.125,  
                horizontal_flip=True,
                vertical_flip=False,
                preprocessing_function = utils.cutout)
train_datagen_pre.fit(X_train)

batch_size = 128
train_datagen = MixupGenerator(X_train, Y_train, batch_size=batch_size, alpha=1.0, datagen=train_datagen_pre)()

keras_model.fit_generator(generator = train_datagen,
                        steps_per_epoch=X_train.shape[0] // batch_size,                    
                        epochs=20, 
                        validation_data=test_datagen.flow(X_test,Y_test, batch_size =batch_size),
                        validation_steps = X_test.shape[0] //batch_size)

Thanks for the code!
Yes, Xavier Glorot introduced this initialization scheme. Some frameworks name it after his first name, while others use his last name.
Besides the potential weight init difference, you are not using any padding in your PyTorch model.
For kernel_size=3 and default values for stride, dilation, etc., you should use padding=1 to keep the spatial size unchanged, which matches padding='same' in Keras.
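For reference, a PyTorch model that mirrors the Keras sketch might look like this. `ConvNetSame` and `conv_block` are hypothetical names; the sketch assumes `padding=1` (like `padding='same'`), uses `nn.AdaptiveAvgPool2d(1)` in place of `GlobalAveragePooling2D`, and returns logits without a softmax so it can be paired with `nn.CrossEntropyLoss`.

```python
import torch
import torch.nn as nn


def conv_block(in_ch, out_ch, pool):
    """Conv -> BatchNorm -> ReLU, optionally followed by 2x2 max pooling."""
    layers = [
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),  # like padding='same'
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    ]
    if pool:
        layers.append(nn.MaxPool2d(kernel_size=2))
    return nn.Sequential(*layers)


class ConvNetSame(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            conv_block(3, 64, pool=True),      # 32x32 -> 16x16
            conv_block(64, 128, pool=True),    # 16x16 -> 8x8
            conv_block(128, 256, pool=False),  # stays 8x8
        )
        self.pool = nn.AdaptiveAvgPool2d(1)    # global average pooling
        self.fc = nn.Linear(256, num_classes)  # outputs logits, no softmax

    def forward(self, x):
        x = self.pool(self.features(x))
        return self.fc(torch.flatten(x, 1))


model = ConvNetSame()
logits = model(torch.randn(2, 3, 32, 32))
```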

I normally use padding=1; I just forgot to add it there.

Now I am testing the weight-initialized version. Also, is there a way to disable weight decay for the batch normalization layers' learnable parameters?

I implemented it as below, based on https://stackoverflow.com/questions/44641976/in-pytorch-how-to-add-l1-regularizer-to-activations, but I am not sure whether it is correct (I didn't use out1, out5, out9, or out13) or whether there is a cleaner way to do it.

for i, (images, labels) in enumerate(trainloader):
    images = images.cuda()
    labels = labels.cuda()
    lambda2 = 0.0005
    # Forward pass
    out, out1, out5, out9, out13 = model(images)
    loss = criterion(out, labels)

    all_1_params = torch.cat([x.view(-1) for x in model.layer1.parameters()])
    all_5_params = torch.cat([x.view(-1) for x in model.layer5.parameters()])
    all_9_params = torch.cat([x.view(-1) for x in model.layer9.parameters()])
    all_13_params = torch.cat([x.view(-1) for x in model.layer13.parameters()])

    l2_regularization_1 = lambda2 * torch.norm(all_1_params, 2)
    l2_regularization_5 = lambda2 * torch.norm(all_5_params, 2)
    l2_regularization_9 = lambda2 * torch.norm(all_9_params, 2)
    l2_regularization_13 = lambda2 * torch.norm(all_13_params, 2)

    loss_all = loss + l2_regularization_1 + l2_regularization_5 + l2_regularization_9 + l2_regularization_13
    # Backward and optimize
    optimizer.zero_grad()
    loss_all.backward()
    optimizer.step()
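A common alternative to manual penalties is to pass parameter groups to the optimizer so that only conv and linear weights are decayed, while biases and batch norm parameters get `weight_decay=0`. This is a sketch with a hypothetical helper `split_decay_params` and a stand-in model. Note also that `torch.norm(params, 2)` is the unsquared L2 norm, whereas Keras's l2 regularizer penalizes the sum of squared weights, so a penalty built from `torch.norm` is not the same regularizer.

```python
import torch
import torch.nn as nn


def split_decay_params(model):
    """Return (decay, no_decay): only conv/linear weights receive weight decay."""
    decay, no_decay = [], []
    for module in model.modules():
        # recurse=False so each parameter is seen exactly once, with its owner
        for name, param in module.named_parameters(recurse=False):
            if isinstance(module, (nn.Conv2d, nn.Linear)) and name == "weight":
                decay.append(param)
            else:
                no_decay.append(param)  # biases, BatchNorm weight and bias
    return decay, no_decay


model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),
)
decay, no_decay = split_decay_params(model)
optimizer = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 0.0005},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.01,
    momentum=0.9,
)
```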

Hey thank you so much for your support!

I reached a maximum of 69.9% validation accuracy with:

  1. random horizontal flip
  2. normalization
  3. random affine for horizontal and vertical translation
  4. mixup(alpha=1.0)
  5. cutout(num_holes=1, size=16)

Random crop was decreasing val_acc; I guess I shouldn't use it together with cutout.

Weight initialization didn't have a substantial effect. Here is what I did:

  1. Remove softmax layer
  2. Remove weight decay from optimizer
  3. Decay only conv2d and linear layer weights manually

When we simply pass weight decay to the optimizer, it decays every parameter the optimizer receives, including biases and the learnable parameters of the batch normalization layers. In my case, this plus the softmax layer had catastrophic effects.

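The Keras model applies kernel regularizers only to the conv2d and dense layers, so it didn't have this problem. It also has a softmax layer, but that is fine there because Keras's categorical_crossentropy expects probabilities.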

Awesome, I’m glad it’s working now!

Just in case I made confusing statements: PyTorch modules also have default initializers, which are different from the ones Keras uses by default.