[SOLVED] Titan V on PyTorch 0.3.0, CUDA 9.0, CUDNN 7.0 is much slower than 1080 Ti

It looks like the Titan V is not as fast as expected compared to the 1080 Ti.
Could you test the full timing of a forward-backward pass on both of your cards? Here is the script I prepared; it benchmarks the three architectures mentioned in both FP32 and FP16 (I'm not sure whether FP16 is used in the right way, could anybody check?). I would like to compare the Tesla V100 vs. Titan V vs. 1080 Ti (for the 1080 Ti, FP32 only).

import torch
import torch.nn as nn
import torch.optim
import torch.backends.cudnn as cudnn
import torchvision.models as models
from time import time
from torch.autograd import Variable

# let cuDNN benchmark and pick the fastest kernels for the fixed input size
cudnn.benchmark = True

print('cuda version=', torch.version.cuda)
print('cudnn version=', torch.backends.cudnn.version())

for arch in ['densenet121', 'vgg16', 'resnet152']:
    # ---- FP32 ----
    model = models.__dict__[arch]().cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), 0.001,
                                momentum=0.9,
                                weight_decay=1e-5)
    durations = []
    num_runs = 100

    for i in range(num_runs + 1):
        x = torch.rand(16, 3, 224, 224)
        x_var = Variable(x).cuda()
        target = Variable(torch.LongTensor(16).fill_(1).cuda())
        torch.cuda.synchronize()
        t1 = time()
        optimizer.zero_grad()  # was missing: clear gradients so they don't accumulate across runs
        out = model(x_var)
        err = criterion(out, target)
        err.backward()
        optimizer.step()
        torch.cuda.synchronize()  # wait for all kernels to finish before stopping the clock
        t2 = time()

        # treat the initial run as warm-up and don't count it
        if i > 0:
            durations.append(t2 - t1)

    print('{} FP 32 avg over {} runs: {} ms'.format(arch, len(durations), sum(durations) / len(durations) * 1000))

    # ---- FP16 ----
    model = models.__dict__[arch]().cuda().half()  # cast parameters to half precision
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), 0.001,
                                momentum=0.9,
                                weight_decay=1e-5)
    durations = []
    num_runs = 100

    for i in range(num_runs + 1):
        x = torch.rand(16, 3, 224, 224)
        x_var = Variable(x).cuda().half()  # half-precision input
        target = Variable(torch.LongTensor(16).fill_(1).cuda())
        torch.cuda.synchronize()
        t1 = time()
        optimizer.zero_grad()  # was missing: clear gradients so they don't accumulate across runs
        out = model(x_var)
        err = criterion(out, target)
        err.backward()
        optimizer.step()
        torch.cuda.synchronize()
        t2 = time()

        # treat the initial run as warm-up and don't count it
        if i > 0:
            durations.append(t2 - t1)

    print('{} FP 16 avg over {} runs: {} ms'.format(arch, len(durations), sum(durations) / len(durations) * 1000))
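On my own FP16 question: one caveat I have seen mentioned is that BatchNorm statistics can become inaccurate in half precision, so a common pattern is to keep the BatchNorm layers in FP32 while the rest of the model runs in FP16. A minimal sketch of that idea (the helper name half_except_bn is mine, and I have not verified that the cuDNN batch-norm kernels on 0.3.0 accept FP16 inputs with FP32 parameters):

import torch.nn as nn

def half_except_bn(model):
    # hypothetical helper: cast everything to FP16, then cast
    # BatchNorm layers back to FP32 for stable running statistics
    model.half()
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.float()
    return model

If that changes the FP16 numbers noticeably, it would suggest the gap is partly a batch-norm precision/kernel issue rather than the card itself.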

I have tested it on a V100 (AWS p3 instance) and here are the results:

('cuda version=', '9.0.176')
('cudnn version=', 7003)
densenet121 FP 32 avg over 100 runs: 66.9194436073 ms
densenet121 FP 16 avg over 100 runs: 59.2770695686 ms
vgg16 FP 32 avg over 100 runs: 85.537352562 ms
vgg16 FP 16 avg over 100 runs: 56.5851545334 ms
resnet152 FP 32 avg over 100 runs: 134.783308506 ms
resnet152 FP 16 avg over 100 runs: 87.0116662979 ms
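As a cross-check on the host-side time() measurements, per-run timing could also be taken with CUDA events; a minimal sketch, reusing the model, x_var, target, criterion, and optimizer from the script above:

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
out = model(x_var)            # forward
err = criterion(out, target)
err.backward()                # backward
optimizer.step()
end.record()

torch.cuda.synchronize()              # wait until both events have completed
elapsed_ms = start.elapsed_time(end)  # GPU-side elapsed time in milliseconds
print('one step took {:.2f} ms'.format(elapsed_ms))

It might also be worth printing torch.cuda.get_device_name(0) next to the CUDA/cuDNN versions, so each set of results records which card produced it.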