Lower accuracy in PyTorch vs. other frameworks

I have been working on a project that runs the same VGG-style CNN on a common dataset (CIFAR-10) across several frameworks, and unfortunately I have not been able to match the other frameworks' accuracy with PyTorch.

Chainer reaches 0.78 test accuracy after 4 min 16 s and CNTK reaches 0.77 after 2 min 48 s. However, PyTorch only averages 0.73, and after nearly 6 minutes.

For example, here is a PyTorch extract:

def create_symbol():
    class SymbolModule(nn.Module):
        def __init__(self):
            super(SymbolModule, self).__init__()
            self.conv1 = nn.Conv2d(3, 50, kernel_size=(3, 3), padding=(1, 1))
            self.conv2 = nn.Conv2d(50, 50, kernel_size=(3, 3), padding=(1, 1))
            self.conv3 = nn.Conv2d(50, 100, kernel_size=(3, 3), padding=(1, 1))
            self.conv4 = nn.Conv2d(100, 100, kernel_size=(3, 3), padding=(1, 1))
            # feature map size is 8*8 by pooling
            self.fc1 = nn.Linear(100*8*8, 512)
            self.fc2 = nn.Linear(512, N_CLASSES)

        def forward(self, x):
            x = F.relu(self.conv2(F.relu(self.conv1(x))))
            x = F.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
            x = F.dropout(x, 0.25)
            
            x = F.relu(self.conv4(F.relu(self.conv3(x))))
            x = F.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
            x = F.dropout(x, 0.25)
                       
            x = x.view(-1, 100*8*8)   # reshape Variable
            x = F.dropout(F.relu(self.fc1(x)), 0.5)
            x = self.fc2(x)
            return F.log_softmax(x, dim=1)
    return SymbolModule()

def init_model(m):
    # Implementation of momentum:
    # v = \rho * v + g \\
    # p = p - lr * v
    opt = optim.SGD(m.parameters(), lr=LR, momentum=MOMENTUM)
    return opt

Chainer extract:

class SymbolModule(chainer.Chain):
    def __init__(self):
        super(SymbolModule, self).__init__(
            conv1=L.Convolution2D(3, 50, ksize=(3,3), pad=(1,1)),
            conv2=L.Convolution2D(50, 50, ksize=(3,3), pad=(1,1)),      
            conv3=L.Convolution2D(50, 100, ksize=(3,3), pad=(1,1)),  
            conv4=L.Convolution2D(100, 100, ksize=(3,3), pad=(1,1)),  
            # feature map size is 8*8 by pooling
            fc1=L.Linear(100*8*8, 512),
            fc2=L.Linear(512, N_CLASSES),
        )

    def __call__(self, x):
        h = F.relu(self.conv2(F.relu(self.conv1(x))))
        h = F.max_pooling_2d(h, ksize=(2,2), stride=(2,2))
        h = F.dropout(h, 0.25)

        h = F.relu(self.conv4(F.relu(self.conv3(h))))
        h = F.max_pooling_2d(h, ksize=(2,2), stride=(2,2))
        h = F.dropout(h, 0.25)

        h = F.dropout(F.relu(self.fc1(h)), 0.5)
        return self.fc2(h)

def init_model(m):
    optimizer = optimizers.MomentumSGD(lr=LR, momentum=MOMENTUM)
    optimizer.setup(m)
    return optimizer

Since I’m comparing the same mathematical operations (albeit starting from randomly initialised weights), I believe I should get roughly the same accuracies when averaged across runs, so I’m not sure what else there is to set in PyTorch. I have experimented a bit with different weight initialisations (the other frameworks use Glorot/Xavier uniform) and checked the gradient-clipping parameters, but this hasn’t made any real difference.

For example:

    self.conv1 = nn.Conv2d(3, 50, kernel_size=3, padding=1)
    init.xavier_uniform(self.conv1.weight, gain=np.sqrt(2.0))
    init.constant(self.conv1.bias, 0.1)
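
To apply the same initialisation to every layer rather than one at a time, something like the following should work (a sketch using Module.apply; the sqrt(2) gain is simply carried over from the snippet above):

import numpy as np
import torch.nn as nn
from torch.nn import init

def glorot_init(layer):
    # Xavier/Glorot uniform weights and a small constant bias for every
    # conv and linear layer, mirroring the per-layer example above.
    if isinstance(layer, (nn.Conv2d, nn.Linear)):
        init.xavier_uniform(layer.weight, gain=np.sqrt(2.0))
        init.constant(layer.bias, 0.1)

sym = create_symbol()
sym.apply(glorot_init)  # applied recursively to every sub-module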

Well, first off, you should do the max-pooling before the ReLU: the two are mathematically equivalent in either order, but it’s nearly 40% more computationally efficient to max-pool first, since the ReLU then runs on a quarter as many activations.

So it would look like this:

F.relu(F.max_pool2d(self.conv2(x), 2))
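
You can sanity-check the equivalence on a random tensor (standalone sketch, not part of the model):

import torch
import torch.nn.functional as F

x = torch.randn(8, 50, 32, 32)  # random feature map

a = F.max_pool2d(F.relu(x), kernel_size=2, stride=2)  # relu, then pool
b = F.relu(F.max_pool2d(x, kernel_size=2, stride=2))  # pool, then relu

# ReLU is monotonic, so it commutes with max-pooling and both orders agree.
print(torch.equal(a, b))  # True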

Are you purposely only max-pooling twice, or did you want it four times, once after each of the four conv2d layers?

There are also a lot of potential differences in the code you haven’t published: the training/validation loops, dataset handling, image preprocessing, etc.

You’ve got a log_softmax on the output of your PyTorch model. I assume you are pairing it with NLLLoss and not CrossEntropyLoss (which already includes the log-softmax, like Chainer’s SoftmaxCrossEntropy)?
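
To be concrete, these two pairings give the same loss, whereas feeding a log_softmax output into CrossEntropyLoss normalises twice (standalone sketch on random tensors):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # raw model outputs
target = torch.randint(0, 10, (4,))   # class indices

# Option 1: model returns raw logits, CrossEntropyLoss applies log-softmax itself.
loss_ce = F.cross_entropy(logits, target)

# Option 2: model ends in log_softmax, paired with the negative log-likelihood loss.
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), target)

print(torch.allclose(loss_ce, loss_nll))  # True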

Thanks for the feedback. I have moved the ReLU after max_pool2d, and indeed I was applying softmax to the output twice (I didn’t realise CrossEntropyLoss() includes it). The updated script is here; however, it still gets 0.72 accuracy and takes around 6 minutes, so weirdly not much has changed.

I was purposefully only max-pooling twice because my feature maps are already quite small (and I do the same thing in Chainer).
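
For reference, the 100*8*8 input to fc1 comes from the pooling arithmetic: CIFAR-10 images are 32x32, and each 2x2, stride-2 max-pool halves the spatial size, so two pools leave 8x8 feature maps.

size = 32           # CIFAR-10 spatial resolution
for _ in range(2):  # two 2x2, stride-2 max-pools
    size //= 2
print(size)         # 8, hence fc1 takes 100 channels * 8 * 8 features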

Ross, the data is the same for all the frameworks (it comes from a common function rather than each library’s own loader). If it’s easier, here are the relevant code extracts:

def create_symbol():
    class SymbolModule(nn.Module):
        def __init__(self):
            super(SymbolModule, self).__init__()
            self.conv1 = nn.Conv2d(3, 50, kernel_size=(3, 3), padding=(1, 1))
            self.conv2 = nn.Conv2d(50, 50, kernel_size=(3, 3), padding=(1, 1))
            self.conv3 = nn.Conv2d(50, 100, kernel_size=(3, 3), padding=(1, 1))
            self.conv4 = nn.Conv2d(100, 100, kernel_size=(3, 3), padding=(1, 1))
            # feature map size is 8*8 by pooling
            self.fc1 = nn.Linear(100*8*8, 512)
            self.fc2 = nn.Linear(512, N_CLASSES)

        def forward(self, x):
            x = self.conv2(F.relu(self.conv1(x)))
            # Apply relu after max-pool
            x = F.relu(F.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2)))
            x = F.dropout(x, 0.25)
            
            x = self.conv4(F.relu(self.conv3(x)))
            # Apply relu after max-pool
            x = F.relu(F.max_pool2d(x, kernel_size=(2, 2), stride=(2, 2)))
            x = F.dropout(x, 0.25)
                       
            x = x.view(-1, 100*8*8)   # reshape Variable
            x = F.dropout(F.relu(self.fc1(x)), 0.5)
            return self.fc2(x)
    return SymbolModule()

def init_model(m):
    # Implementation of momentum:
    # v = \rho * v + g \\
    # p = p - lr * v
    opt = optim.SGD(m.parameters(), lr=LR, momentum=MOMENTUM)
    # Combines softmax output with negative log-likelihood
    criterion = nn.CrossEntropyLoss()
    return opt, criterion

x_train, x_test, y_train, y_test = cifar_for_library(channel_first=True)
y_train = y_train.astype(np.int64)
y_test = y_test.astype(np.int64)

sym = create_symbol()
sym.cuda() # CUDA!
optimizer, criterion = init_model(sym)
sym.train()  
for j in range(EPOCHS):
    for data, target in yield_mb(x_train, y_train, BATCHSIZE, shuffle=True):
        # Get samples
        data = Variable(torch.FloatTensor(data).cuda())
        target = Variable(torch.LongTensor(target).cuda())
        # Init
        optimizer.zero_grad()
        # Forwards
        output = sym(data)
        # Loss
        loss = criterion(output, target)
        # Back-prop
        loss.backward()
        optimizer.step()
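
For reference, the test side is just the mirror image of the loop above (a sketch: it assumes the same yield_mb helper accepts shuffle=False, and uses the x_test/y_test arrays from cifar_for_library above):

sym.eval()  # switch modules to evaluation mode
correct = 0
total = 0
for data, target in yield_mb(x_test, y_test, BATCHSIZE, shuffle=False):
    data = Variable(torch.FloatTensor(data).cuda())
    output = sym(data)
    pred = output.data.max(1)[1].cpu().numpy()  # index of the highest logit
    correct += (pred == target).sum()
    total += len(target)
print("Accuracy: %.4f" % (correct / total))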