Any reason why the following PyTorch (3s/epoch) code is so much slower than MXNet's version (~0.6s/epoch)?

from __future__ import print_function
import time
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms

tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),    
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),    
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),    
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),    
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)] 
for i in range(len(tableau20)):    
    r, g, b = tableau20[i]    
    tableau20[i] = (r / 255., g / 255., b / 255.)

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=30, metavar='N',
                    help='number of epochs to train (default: 30)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight_decay', type=float, default=0.0005, metavar='W',
                    help='SGD weight decay (default: 0.0005)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--vis_path', type=str, default="visualizations/color6", metavar='S',
                    help='path to save your visualization figures')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")


kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.Linear):
        nn.init.constant_(m.bias, 0)

 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(1024, 256)
        self.fc2 = nn.Linear(256, 2)
        self.fc3 = nn.Linear(2, 10)

    def forward(self, x, y=None):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2, stride=2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2, stride=2)
        x = x.view(-1, 1024)
        x = F.relu(self.fc1(x))
        features = self.fc2(x)
        x = self.fc3(features)
        return F.log_softmax(x, dim=1)

model = Net().to(device)
model.apply(weights_init)

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

def train(epoch):
    model.train()
    for batch_idx, (batch_data, batch_labels) in enumerate(train_loader):
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        batch_scores = model(batch_data)
        loss = F.nll_loss(batch_scores, batch_labels)
        
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(batch_data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_data, batch_labels in test_loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            batch_scores = model(batch_data)
            test_loss += F.nll_loss(batch_scores, batch_labels, size_average=False).item() # sum up batch loss
            pred = batch_scores.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(batch_labels.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 
for epoch in range(1, args.epochs + 1):
    scheduler.step()
    start = time.time()
    train(epoch)
    end = time.time()
    print('Total time taken: {:.2f}s\n'.format(end-start))
    test()

The MXNet version is here: https://github.com/luoyetx/mx-lsoftmax. I ran it without the lsoftmax layer using the command python2 mnist.py --train --no-lsoftmax --gpu 0 --batch_size 256.

Am I doing something wrong, or is PyTorch just that much slower?

I am using a GTX 1080 Ti with CUDA 9 and cuDNN 7.

For optimal speed, I would suggest that you omit the log-softmax in the model; instead, return the output of fc3 directly and use nn.CrossEntropyLoss in place of F.nll_loss.
Furthermore, be aware that a task that takes 0.6-3 s per epoch is not a meaningful speed test. Some performance optimizations incur a small fixed overhead that dominates in a tiny task, whereas in a more expensive task the speedup is significant and the overhead is negligible. I suggest you use CIFAR-10 with a sufficient number of layers when doing performance comparisons.
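
Roughly, the first change would look like this (an untested sketch based on the Net you posted):

# in Net.forward: return the raw fc3 logits, no log_softmax
def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)), 2, stride=2)
    x = F.max_pool2d(F.relu(self.conv2(x)), 2, stride=2)
    x = x.view(-1, 1024)
    x = F.relu(self.fc1(x))
    features = self.fc2(x)
    return self.fc3(features)

# in the training loop: CrossEntropyLoss fuses log-softmax + NLL in one call
criterion = nn.CrossEntropyLoss()
loss = criterion(model(batch_data), batch_labels)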

It is still 3s after using your suggestion.

The argument that I need more layers for the benchmark doesn't make sense. Either I am doing something wrong, or PyTorch is just inherently slow.

Hi,

I think the point was that for such a small model, the overhead of running the Python code is actually noticeable compared to something that runs as pure C++.
Printing every 10 batches can slow the process down quite a lot as well.
Finally, even though it should be minimal here, you need to synchronize the GPU with torch.cuda.synchronize() to get accurate timings of GPU compute.

  1. Isn’t the mxnet code also Python? It does its printing from Python too. Does that mean there is nothing I can do to narrow the gap for this code?
  2. How can I synchronize the timing? Do you mind showing an example?
  1. I did not look in detail, but the repo you linked contains a full CUDA implementation of the code, and the readme contains a discussion of Python vs C++ code here.
  2. The whole CUDA API is asynchronous, so at any point in your Python code there may still be work running from previous lines. To get proper timings you should always call things in the following order (note that adding many CUDA syncs can slow down your process):
torch.cuda.synchronize()
current_time = time.time()
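
So applied to your script, the per-epoch timing would look roughly like this (just the same pattern wrapped around your train call):

torch.cuda.synchronize()   # make sure previously queued GPU work is done before starting the clock
start = time.time()
train(epoch)
torch.cuda.synchronize()   # wait for this epoch's GPU work to actually finish
end = time.time()
print('Total time taken: {:.2f}s'.format(end - start))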

Yup, the repo contains a CUDA implementation of a layer (lsoftmax), but I did not use that layer. Like I said, I ran the code with the command python2 mnist.py --train --no-lsoftmax --gpu 0 --batch_size 256.

Hence it uses a normal fully connected layer as the last layer.

Also, I tried removing the printing lines, and it does not make much of a speed difference: still 3 s.

Different Python interfaces work in different ways. I encourage you to compare the performance of the raw MXNet Python API and the MXNet Gluon API. Although both are Python, their performance differs. For instance, with the Gluon API you can optionally compile a model for better speed; however, if the model is very shallow, in my experience the compiled version is often slower. Yet both are Python.
I suggest you try larger models and base your observations on those.

OK, I will try that. Can I conclude that for PyTorch this is the fastest I can get? 5x slower?

It is quite pointless to compare with a slower version of another framework. I only care about the fastest possible speed.

I can’t tell which is faster; however, in more realistic deep learning experiments, which typically involve larger models, PyTorch is highly comparable to the fastest frameworks. There is a lot more to benchmarking deep learning models. From an engineering perspective, consider raw Java and JNI: for basic tasks, the overhead of calling C code from Java via JNI makes JNI slower than pure Java, but for more compute-intensive tasks, JNI outperforms pure Java code. Hence, benchmarks must be run at large scale, especially in deep learning, where the whole point of the GPU is to perform highly compute-intensive work.

For this code, what is the overhead that you are referring to?

And why do other frameworks not have this overhead, only PyTorch?

Different frameworks make different engineering design choices.
For example, PyTorch works with dynamic graphs, which give you more flexibility in what your forward pass can do (branching, for example) and use a purely imperative declaration of the model. The downside is that your Python “forward” method is called every time you perform a forward pass, so every forward pass pays this cost on top of the cost of running the network itself.
On the other hand, this lets you implement arbitrary Python logic and get proper Python errors when running your model.
The point that johnolafenwa was making is that if your model is big enough, the cost of executing the Python code becomes negligible compared to the network execution itself (which is the case for “large” networks).

And there are many other implementation details that can make the performance vary for different use cases.

So the conclusion you can draw from your experiment is that training this specific network on MNIST with this code (which looks fine) is slower in PyTorch than with the MXNet implementation you linked.
It is expected that every framework will be better at some workloads than others.

From my understanding, the general conclusion is that:

  • For large nets, what matters is the backend: you should use CUDA and cuDNN for best performance.
  • A static graph should be faster (implementation details can change that), especially for small workloads. See how PyTorch plans to address that here.

It’s not only PyTorch. Take a look at this IBM benchmark of various high-performance scientific libraries on the LU factorization problem: as you can see from the results, performance at small problem sizes can be highly misleading. You really should look at https://github.com/soumith/convnet-benchmarks for a more comprehensive benchmark of deep learning frameworks.

Thanks for the clearer explanation. PyTorch 1.0 is really exciting; it would really bridge the deep learning workflow end to end.

MNIST is incredibly small and not very representative of other datasets, so the things that make MNIST training as fast as possible won’t be the same things that make, say, ImageNet training as fast as possible.

Here are some recommendations:

  1. Increase the number of workers. Currently, you’re very data loading bound.
  2. Set torch.backends.cudnn.benchmark = True. (See What does torch.backends.cudnn.benchmark do?)
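
Roughly, both of those look like this applied to the script above (the worker count here is just a placeholder to tune for your machine):

# let cuDNN auto-tune its convolution algorithms (input sizes are fixed here)
torch.backends.cudnn.benchmark = True

# use more loader worker processes so the GPU is not waiting on data
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)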

Really, for MNIST, you should just put the entire dataset on the GPU to start and do your data normalization once. But that doesn’t transfer to other, larger datasets.
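
For reference, a rough sketch of that idea (untested; it assumes the whole normalized training set fits in GPU memory, which it easily does for MNIST at roughly 180 MB in float32, and uses the torchvision attribute names from this era, which may differ in newer releases). gpu_batches is just an illustrative helper, not part of torchvision:

# load MNIST once, normalize once, and keep everything resident on the GPU
train_set = datasets.MNIST('../data', train=True, download=True)
data = train_set.train_data.float().div_(255).sub_(0.1307).div_(0.3081)
data = data.unsqueeze(1).to(device)            # shape: [60000, 1, 28, 28]
labels = train_set.train_labels.to(device)

def gpu_batches(batch_size):
    # yield shuffled minibatches by indexing the resident GPU tensors
    perm = torch.randperm(data.size(0), device=device)
    for i in range(0, data.size(0), batch_size):
        idx = perm[i:i + batch_size]
        yield data[idx], labels[idx]

train() would then iterate over gpu_batches(args.batch_size) instead of train_loader.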

Is there any rule of thumb for setting the number of workers? It doesn’t seem like more is always faster. I tried between 1 and 20, and 6 seems to be the fastest; 10 is faster than 20.

Is this behavior normal?

Do you mind showing how to load the entire MNIST dataset onto the GPU and train? The PyTorch tutorials don’t seem to cover that.

I have successfully reduced the time to 1 s with num_workers=6, but it is still about 50% slower than MXNet.

It is quite interesting that when I use a bigger model, the difference is even worse, which contradicts what everyone here is saying about Python’s overhead cost. For the above model, I successfully reduced the time with a larger number of workers (6) (still slightly slower). However, for this model,

from __future__ import print_function
import time
import math
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
#from large_margin_softmax_linear import LargeMarginSoftmaxLinear
from experiment import LargeMarginSoftmaxLinear

torch.set_printoptions(precision=20)

tableau20 = [(31, 119, 180), (174, 199, 232), (255, 127, 14), (255, 187, 120),    
             (44, 160, 44), (152, 223, 138), (214, 39, 40), (255, 152, 150),    
             (148, 103, 189), (197, 176, 213), (140, 86, 75), (196, 156, 148),    
             (227, 119, 194), (247, 182, 210), (127, 127, 127), (199, 199, 199),    
             (188, 189, 34), (219, 219, 141), (23, 190, 207), (158, 218, 229)] 
for i in range(len(tableau20)):    
    r, g, b = tableau20[i]    
    tableau20[i] = (r / 255., g / 255., b / 255.)

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=256, metavar='N',
                    help='input batch size for training (default: 256)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=30, metavar='N',
                    help='number of epochs to train (default: 30)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='SGD momentum (default: 0.9)')
parser.add_argument('--weight_decay', type=float, default=0.0005, metavar='W',
                    help='SGD weight decay (default: 0.0005)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--vis_path', type=str, default="visualizations/color6", metavar='S',
                    help='path to save your visualization figures')
args = parser.parse_args()
#use_cuda = False
use_cuda = not args.no_cuda and torch.cuda.is_available()
print(use_cuda)

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")


torch.backends.cudnn.benchmark = True if use_cuda else False

print(device)
kwargs = {'num_workers': 6, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        #nn.init.xavier_uniform(m.weight.batch_data)
    #elif isinstance(m, nn.BatchNorm2d):
    #    m.weight.batch_data.fill_(1)
    #    m.bias.batch_data.zero_()    
    #elif isinstance(m, nn.BatchNorm1d):
    #    m.weight.batch_data.fill_(1)
    #    m.bias.batch_data.zero_()    

 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5, padding=2)
        #self.bn1 = nn.BatchNorm2d(32)
        self.prelu1 = nn.PReLU()
        self.conv2 = nn.Conv2d(32, 32, kernel_size=5, padding=2)
        #self.bn2 = nn.BatchNorm2d(32)
        self.prelu2 = nn.PReLU()
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5, padding=2)
        #self.bn3 = nn.BatchNorm2d(64)
        self.prelu3 = nn.PReLU()
        self.conv4 = nn.Conv2d(64, 64, kernel_size=5, padding=2)
        #self.bn4 = nn.BatchNorm2d(64)
        self.prelu4 = nn.PReLU()
        self.conv5 = nn.Conv2d(64, 128, kernel_size=5, padding=2)
        #self.bn5 = nn.BatchNorm2d(128)
        self.prelu5 = nn.PReLU()
        self.conv6 = nn.Conv2d(128, 128, kernel_size=5, padding=2)
        #self.bn6 = nn.BatchNorm2d(128)
        self.prelu6 = nn.PReLU()
        self.fc1 = nn.Linear(1152, 2)
        #self.bn7 = nn.BatchNorm1d(2)
        self.prelu7 = nn.PReLU()
        self.fc2 = nn.Linear(2, 10)
        self.loss = nn.CrossEntropyLoss()
        #self.fc2 = LargeMarginSoftmaxLinear(2, 10, 2, 0)
        #self.fc2 = LargeMarginSoftmaxLinear(2, 10, 2, 0, use_cuda, device)

    def forward(self, x, y):
        x = self.prelu1(self.conv1(x))
        x = F.max_pool2d(self.prelu2(self.conv2(x)), 2, stride=2)
        x = self.prelu3(self.conv3(x))
        x = F.max_pool2d(self.prelu4(self.conv4(x)), 2, stride=2)
        x = self.prelu5(self.conv5(x))
        x = F.max_pool2d(self.prelu6(self.conv6(x)), 2, stride=2)

        #x = F.relu(self.bn1(self.conv1(x)))
        #x = F.max_pool2d(F.relu(self.bn2(self.conv2(x))), 2, stride=2)
        #x = F.relu(self.bn3(self.conv3(x)))
        #x = F.max_pool2d(F.relu(self.bn4(self.conv4(x))), 2, stride=2)
        #x = F.relu(self.bn5(self.conv5(x)))
        #x = F.max_pool2d(F.relu(self.bn6(self.conv6(x))), 2, stride=2)
        x = x.view(-1, 1152)
        #print(x.shape)
        features = self.fc1(x)
        x = self.prelu7(features)
        #x = F.relu(self.bn7(self.fc1(x)))
        #x = self.fc2(x, y)
        x = self.fc2(x)
        #print(x)
        return self.loss(x, y)

model = Net().to(device)
model.apply(weights_init)

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

def visualize(features, labels, epoch, vis_path):
    plt.clf()
    for i in range(10):
        plt.scatter(features[(labels == i), 0], features[(labels == i), 1], s=3, c=tableau20[i], alpha=0.8) 
    plt.savefig(vis_path + "/epoch_" + str(epoch) + ".jpg")
    return

def train(epoch):
    model.train()
    #features = []
    #predictions = []        
    #labels = []
    for batch_idx, (batch_data, batch_labels) in enumerate(train_loader):

        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        #print(batch_data.dtype)
        #print(batch_labels.dtype)
        #print(batch_data.size())
        #print(batch_labels.size())
        optimizer.zero_grad()
        loss = model(batch_data, batch_labels)
        #loss = F.nll_loss(batch_scores, batch_labels)
        #print(loss)
        
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(batch_data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        #features.append(batch_features)
        #batch_predictions = batch_scores.max(1, keepdim=True)[1]
        #predictions.append(batch_predictions)
        #labels.append(batch_labels)
    
    #features = torch.cat(features, 0).data.to('cpu').numpy()
    #predictions = torch.cat(predictions, 0).data.cpu().numpy()
    #labels = torch.cat(labels, 0).data.to('cpu').numpy()
    #visualize(features, labels, epoch, args.vis_path)

# NOTE: test() still assumes the old forward() that returned class scores; it has
# not been updated for the new forward(x, y) above and is not called below.
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_data, batch_labels in test_loader:
            batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
            batch_scores, batch_features = model(batch_data)
            test_loss += F.nll_loss(batch_scores, batch_labels, size_average=False).item() # sum up batch loss
            pred = batch_scores.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(batch_labels.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 
for epoch in range(1, args.epochs + 1):
    scheduler.step()
    #lr = []
    #for param_group in optimizer.param_groups:
    #    lr += [param_group['lr']]
    #print(lr)
    torch.cuda.synchronize()

    start = time.perf_counter()
    train(epoch)
    torch.cuda.synchronize()
    end = time.perf_counter()
    #print(end-start)
    print('Total time taken: {:.2f}s\n'.format(end-start))
    #test()

It takes 5 s while MXNet takes only 1 s (a 5x difference).

Are you still suggesting that this is small?

I just tried loading the entire dataset onto the GPU; the speed gain is only about 0.08 s per iteration. It is still much slower than MXNet for the larger network above (MXNet’s 1 s vs PyTorch’s 4.27 s).

It seems like the larger network is even worse (the new one above).

  1. I don’t think it’s Python overhead. The GPU utilization as reported by nvidia-smi is relatively high.
  2. Most of the computation is spent in cuDNN convolution. I’m seeing ~4 seconds per epoch, which is ~18 ms per iteration (235 iterations per epoch). Convolution forward and backward take ~13.5 ms, PReLU takes ~3 ms (which seems high), and the remaining ~2 ms is probably miscellaneous calls.

(Note that the hierarchy of calls isn’t displayed properly; many of the “convolution” entries are the same call.)

What’s the comparable MXNet code you’re running? I think MXNet also relies on cuDNN for convolutions.

EDIT: here’s the script I’m using for profiling: https://gist.github.com/colesbury/2b786128f4409dbd07eef44839f0b3ce
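
(For completeness, a minimal version of this kind of profiling with the built-in autograd profiler looks roughly like the following; the gist above is more detailed.)

# profile a few training iterations and print per-op CUDA time totals
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    for i, (batch_data, batch_labels) in enumerate(train_loader):
        if i == 10:
            break
        batch_data, batch_labels = batch_data.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        loss = model(batch_data, batch_labels)   # this Net's forward returns the loss
        loss.backward()
        optimizer.step()
print(prof.key_averages().table(sort_by='cuda_time_total'))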