Under some conditions, Conv3D outputs differ greatly depending on the Apex amp opt-level on a V100 GPU

While training a 3D CNN with Apex amp, I got a very strange result (accuracy drops by around 6%), specifically with PyTorch 1.4.0, a V100 GPU, and amp opt-level O1 or O2.
With PyTorch 1.1.0, it worked correctly on both P40 and V100 GPUs with opt-level O2.
With PyTorch 1.4.0, it worked correctly on a P40 GPU with opt-level O2 (probably also with O0).
With PyTorch 1.4.0, it worked correctly on a V100 GPU with opt-level O0.

So I compared every module’s output in my model between opt-levels O0 and O2 on CUDA 9.2, PyTorch 1.4.0, and a V100 GPU.
I found that the outputs of Conv3D with kernel_size (3,1,1) differ greatly depending on the opt-level.
The code snippet below reproduces the behavior.
It seems to happen when the channel size is large.

import torch
import torch.nn as nn
import apex
from apex import amp

def init_model(in_channel=8, out_channel=4, kernel_size=(3, 1, 1), padding=(1, 0, 0), opt_level='O0'):
    model = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size, padding=padding, bias=False)
    model = model.cuda()
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                    lr=0.1,
                                    )

    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=opt_level,
                                      keep_batchnorm_fp32=None if opt_level=='O1' else True,
                                      loss_scale=None
                                      )
    return model, optimizer

def compare_opt_level(in_channel, out_channel, kernel_size, padding, opt1='O0', opt2='O2'):
    input = torch.randn(size=(4, in_channel, 8, 16, 16)).cuda()

    n0, _ = init_model(in_channel, out_channel, kernel_size=kernel_size, padding=padding, opt_level=opt1)
    n2, _ = init_model(in_channel, out_channel, kernel_size=kernel_size, padding=padding, opt_level=opt2)
    init_param = torch.randn_like(n0.weight)
    #init_param = nn.init.kaiming_normal(n0.weight, mode='fan_out')
    n0.weight = nn.Parameter(init_param)
    n2.weight = nn.Parameter(init_param.half())

    v0 = n0(input)
    v2 = n2(input)

    print('/////////////////////////////')
    print('Compare {}/{}, in/out channel {}/{}, kernel size {}'.format(
        opt1, opt2, in_channel, out_channel, kernel_size))
    print(torch.sqrt((v0 - v2) ** 2).mean(dim=(1, 2, 3, 4)))

if __name__=='__main__':
    torch.random.manual_seed(1)
    ic, oc = 8, 4
    compare_opt_level(ic, oc, kernel_size=(3, 1, 1), padding=(1, 0, 0))
    compare_opt_level(ic, oc, kernel_size=(3, 3, 3), padding=(1, 1, 1))
    compare_opt_level(ic, oc, kernel_size=(1, 1, 1), padding=(0, 0, 0))
    compare_opt_level(ic, oc, kernel_size=(1, 3, 3), padding=(0, 1, 1))
    
    ic, oc = 512, 128
    compare_opt_level(ic, oc, kernel_size=(3, 1, 1), padding=(1, 0, 0), opt1='O0', opt2='O2')
    compare_opt_level(ic, oc, kernel_size=(3, 3, 3), padding=(1, 1, 1), opt1='O0', opt2='O2')
    compare_opt_level(ic, oc, kernel_size=(1, 1, 1), padding=(0, 0, 0), opt1='O0', opt2='O2')
    compare_opt_level(ic, oc, kernel_size=(1, 3, 3), padding=(0, 1, 1), opt1='O0', opt2='O2')

    ic, oc = 512, 128
    compare_opt_level(ic, oc, kernel_size=(3, 1, 1), padding=(1, 0, 0), opt1='O0', opt2='O1')
    compare_opt_level(ic, oc, kernel_size=(3, 3, 3), padding=(1, 1, 1), opt1='O0', opt2='O1')
    compare_opt_level(ic, oc, kernel_size=(1, 1, 1), padding=(0, 0, 0), opt1='O0', opt2='O1')
    compare_opt_level(ic, oc, kernel_size=(1, 3, 3), padding=(0, 1, 1), opt1='O0', opt2='O1')

The printed MSEs between O0 and O2 on the V100 are:

Compare O0/O2, in/out channel 8/4, kernel size (3, 1, 1)
tensor([0.0016, 0.0017, 0.0016, 0.0016], device='cuda:0',
       grad_fn=<MeanBackward1>)

Compare O0/O2, in/out channel 8/4, kernel size (3, 3, 3)
tensor([0.0042, 0.0043, 0.0042, 0.0042], device='cuda:0',
       grad_fn=<MeanBackward1>)

Compare O0/O2, in/out channel 8/4, kernel size (1, 1, 1)
tensor([0.0007, 0.0007, 0.0007, 0.0007], device='cuda:0',
       grad_fn=<MeanBackward1>)

Compare O0/O2, in/out channel 8/4, kernel size (1, 3, 3)
tensor([0.0021, 0.0021, 0.0022, 0.0021], device='cuda:0',
       grad_fn=<MeanBackward1>)

Compare O0/O2, in/out channel 512/128, kernel size (3, 1, 1)
tensor([34.6549, 34.6763, 34.7861, 34.7793], device='cuda:0',
       grad_fn=<MeanBackward1>)

Compare O0/O2, in/out channel 512/128, kernel size (3, 3, 3)
tensor([0.0303, 0.0303, 0.0303, 0.0303], device='cuda:0',
       grad_fn=<MeanBackward1>)

Compare O0/O2, in/out channel 512/128, kernel size (1, 1, 1)
tensor([0.0064, 0.0064, 0.0064, 0.0064], device='cuda:0',
       grad_fn=<MeanBackward1>)

Compare O0/O2, in/out channel 512/128, kernel size (1, 3, 3)
tensor([0.0183, 0.0183, 0.0183, 0.0183], device='cuda:0',
       grad_fn=<MeanBackward1>)

You can see that the MSE is exceptionally large when the kernel size is (3, 1, 1) with large channel counts.
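One caveat on how to read these numbers: `torch.sqrt((v0 - v2) ** 2).mean(...)` takes the square root element-wise before averaging, so it reports a per-sample mean absolute difference rather than a true mean squared error. A minimal sketch of both metrics on toy tensors (the helper names are hypothetical and not part of the original script):

```python
import torch

def per_sample_mae(a, b):
    # Mean absolute difference per sample; same value as sqrt((a - b) ** 2).mean(...)
    return (a.float() - b.float()).abs().flatten(1).mean(dim=1)

def per_sample_mse(a, b):
    # Mean squared error per sample (squares are averaged, no square root)
    return ((a.float() - b.float()) ** 2).flatten(1).mean(dim=1)

# Toy 5D tensors shaped like Conv3d outputs: (N, C, D, H, W)
a = torch.zeros(2, 1, 1, 2, 2)
b = torch.full((2, 1, 1, 2, 2), 3.0)

print(per_sample_mae(a, b))
print(per_sample_mse(a, b))
```

Either metric makes the (3, 1, 1) outlier obvious; the distinction only matters when comparing the magnitudes against other reports.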

Unlike on the V100, the MSE is not so large on the P40:

Compare O0/O2, in/out channel 8/4, kernel size (3, 1, 1)
tensor([0.0014, 0.0014, 0.0014, 0.0014], device='cuda:0',
       grad_fn=<MeanBackward2>)

Compare O0/O2, in/out channel 8/4, kernel size (3, 3, 3)
tensor([0.0038, 0.0037, 0.0037, 0.0037], device='cuda:0',
       grad_fn=<MeanBackward2>)

Compare O0/O2, in/out channel 8/4, kernel size (1, 1, 1)
tensor([0.0009, 0.0009, 0.0009, 0.0009], device='cuda:0',
       grad_fn=<MeanBackward2>)

Compare O0/O2, in/out channel 8/4, kernel size (1, 3, 3)
tensor([0.0023, 0.0023, 0.0023, 0.0023], device='cuda:0',
       grad_fn=<MeanBackward2>)

Compare O0/O2, in/out channel 512/128, kernel size (3, 1, 1)
tensor([0.0106, 0.0106, 0.0106, 0.0106], device='cuda:0',
       grad_fn=<MeanBackward2>)

Compare O0/O2, in/out channel 512/128, kernel size (3, 3, 3)
tensor([0.0305, 0.0303, 0.0303, 0.0304], device='cuda:0',
       grad_fn=<MeanBackward2>)

Compare O0/O2, in/out channel 512/128, kernel size (1, 1, 1)
tensor([0.0064, 0.0064, 0.0064, 0.0064], device='cuda:0',
       grad_fn=<MeanBackward2>)

Compare O0/O2, in/out channel 512/128, kernel size (1, 3, 3)
tensor([0.0184, 0.0184, 0.0183, 0.0183], device='cuda:0',
       grad_fn=<MeanBackward2>)

When I compare O0 and O1, the MSEs are all zero.
However, training also did not work properly with O1 on the V100, so there might be some other issues.

I recently found that there is torch.cuda.amp in the nightly version, but I haven’t tested it.
I cannot move to torch.cuda.amp because it seems that an O2-like level is not provided yet.
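For reference, a minimal sketch of what one training step with native amp looks like, assuming a PyTorch build that ships `torch.cuda.amp` (1.6 or later). The layer sizes, data, and loss are made up for illustration, and autocast and loss scaling are simply disabled when no GPU is present:

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = nn.Conv3d(8, 4, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# GradScaler performs dynamic loss scaling; it is a no-op when disabled.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(2, 8, 4, 8, 8, device=device)
target = torch.randn(2, 4, 4, 8, 8, device=device)

optimizer.zero_grad()
# autocast picks the precision per operation; the model weights stay in float32.
with torch.cuda.amp.autocast(enabled=use_cuda):
    out = model(x)
    loss = nn.functional.mse_loss(out.float(), target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

print(out.shape, loss.item(), model.weight.dtype)
```

Because the weights are never cast to half, this behaves more like O1 than O2.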

Calling amp.initialize multiple times with different opt_levels is not supported, so you should split the test code into separate runs using the same data (load the random tensors from disk and reinitialize the model with the same state_dict).
Also, you shouldn’t call half() on any parameter, model, or data.

Thanks @ptrblck, I tested as you suggested (on V100).
Sorry that the modified code below is not well organized.

I first save the state_dict and input to disk as follows:

import torch
import torch.nn as nn
import apex
from apex import amp

def init_model(in_channel=8, out_channel=4, kernel_size=(3, 1, 1), padding=(1, 0, 0), opt_level='O0'):
    model = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size, padding=padding, bias=False)
    model = model.cuda()
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                    lr=0.1,
                                    )

    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=opt_level,
                                      keep_batchnorm_fp32=None if opt_level=='O1' else True,
                                      loss_scale=None
                                      )
    return model, optimizer

def compare_opt_level(in_channel, out_channel, kernel_size, padding, opt1='O0', opt2='O2'):
    input = torch.randn(size=(4, in_channel, 8, 16, 16)).cuda()

    n0, _ = init_model(in_channel, out_channel, kernel_size=kernel_size, padding=padding, opt_level=opt1)
    init_param = torch.randn_like(n0.weight)
    n0.weight = nn.Parameter(init_param)

    v0 = n0(input)

    save_dict = {'input': input.cpu(), 'state_dict':n0.state_dict()}
    torch.save(save_dict, 'ckpt_{}_{}_{}.pt'.format('_'.join(map(str, kernel_size)), in_channel, out_channel))

if __name__=='__main__':
    torch.random.manual_seed(1)
    ic, oc = 8, 4
    compare_opt_level(ic, oc, kernel_size=(3, 1, 1), padding=(1, 0, 0))
    compare_opt_level(ic, oc, kernel_size=(3, 3, 3), padding=(1, 1, 1))
    compare_opt_level(ic, oc, kernel_size=(1, 1, 1), padding=(0, 0, 0))
    compare_opt_level(ic, oc, kernel_size=(1, 3, 3), padding=(0, 1, 1))
    
    ic, oc = 512, 128
    compare_opt_level(ic, oc, kernel_size=(3, 1, 1), padding=(1, 0, 0), opt1='O0', opt2='O2')
    compare_opt_level(ic, oc, kernel_size=(3, 3, 3), padding=(1, 1, 1), opt1='O0', opt2='O2')
    compare_opt_level(ic, oc, kernel_size=(1, 1, 1), padding=(0, 0, 0), opt1='O0', opt2='O2')
    compare_opt_level(ic, oc, kernel_size=(1, 3, 3), padding=(0, 1, 1), opt1='O0', opt2='O2')

Then I save the output of the model for each opt-level separately with the following code:

import argparse
import torch
import torch.nn as nn
import apex
from apex import amp

def init_model(in_channel=8, out_channel=4, kernel_size=(3, 1, 1), padding=(1, 0, 0), opt_level='O0', ckpt=None):
    model = nn.Conv3d(in_channel, out_channel, kernel_size=kernel_size, padding=padding, bias=False)
    model.load_state_dict(ckpt['state_dict'])
    model = model.cuda()
    parameters = model.parameters()
    optimizer = torch.optim.SGD(parameters,
                                    lr=0.1,
                                    )

    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=opt_level,
                                      keep_batchnorm_fp32=None if opt_level=='O1' else True,
                                      loss_scale=None
                                      )
    return model, optimizer

def get_result(in_channel, out_channel, kernel_size, padding, opt_level='O0'):
    checkpoint = torch.load('ckpt_{}_{}_{}.pt'.format('_'.join(map(str, kernel_size)), in_channel, out_channel))
    input = checkpoint['input'].cuda()

    n0, _ = init_model(in_channel, out_channel, kernel_size=kernel_size, padding=padding, opt_level=opt_level,
                       ckpt = checkpoint)
    v0 = n0(input)

    return v0


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt_level', type=str)
    args = parser.parse_args()

    result_dict = {}
    ic, oc = 8, 4
    ks = [[3, 1, 1], [3, 3, 3], [1, 1, 1], [1, 3, 3]]
    for k in ks:
        r = get_result(ic, oc, kernel_size=k, padding=[(i-1)//2 for i in k], opt_level=args.opt_level)
        result_dict['{}_{}_{}'.format('_'.join(map(str, k)), ic, oc)] = r

    ic, oc = 512, 128
    for k in ks:
        r = get_result(ic, oc, kernel_size=k, padding=[(i-1)//2 for i in k], opt_level=args.opt_level)
        result_dict['{}_{}_{}'.format('_'.join(map(str, k)), ic, oc)] = r

    torch.save(result_dict, 'result_{}'.format(args.opt_level))

Then I compare the result files as follows:

import torch

if __name__ == '__main__':
    opt1 = 'O0'
    opt2 = 'O2'
    dict_O0 = torch.load('./result_'+opt1)
    dict_O2 = torch.load('./result_'+opt2)

    for k in sorted(dict_O0.keys()):

        v0 = dict_O0[k]
        v2 = dict_O2[k]
        k1, k2, k3, in_channel, out_channel = k.split('_')
        kernel_size = (k1, k2, k3)
        print('/////////////////////////////')
        print('Compare {}/{}, in/out channel {}/{}, kernel size {}'.format(
            opt1, opt2, in_channel, out_channel, kernel_size))
        print(torch.sqrt((v0 - v2) ** 2).mean(dim=(1, 2, 3, 4)))

And I got a similar result to before.

Compare O0/O2, in/out channel 512/128, kernel size ('1', '1', '1')
tensor([0.0064, 0.0064, 0.0064, 0.0064], device='cuda:0',
       grad_fn=<MeanBackward2>)
/////////////////////////////
Compare O0/O2, in/out channel 8/4, kernel size ('1', '1', '1')
tensor([0.0008, 0.0009, 0.0009, 0.0009], device='cuda:0',
       grad_fn=<MeanBackward2>)
/////////////////////////////
Compare O0/O2, in/out channel 512/128, kernel size ('1', '3', '3')
tensor([0.0184, 0.0184, 0.0184, 0.0183], device='cuda:0',
       grad_fn=<MeanBackward2>)
/////////////////////////////
Compare O0/O2, in/out channel 8/4, kernel size ('1', '3', '3')
tensor([0.0023, 0.0023, 0.0023, 0.0023], device='cuda:0',
       grad_fn=<MeanBackward2>)
/////////////////////////////
Compare O0/O2, in/out channel 512/128, kernel size ('3', '1', '1')
tensor([34.9032, 34.8360, 34.9023, 34.8901], device='cuda:0',
       grad_fn=<MeanBackward2>)
/////////////////////////////
Compare O0/O2, in/out channel 8/4, kernel size ('3', '1', '1')
tensor([0.0016, 0.0016, 0.0016, 0.0016], device='cuda:0',
       grad_fn=<MeanBackward2>)
/////////////////////////////
Compare O0/O2, in/out channel 512/128, kernel size ('3', '3', '3')
tensor([0.0303, 0.0303, 0.0304, 0.0303], device='cuda:0',
       grad_fn=<MeanBackward2>)
/////////////////////////////
Compare O0/O2, in/out channel 8/4, kernel size ('3', '3', '3')
tensor([0.0043, 0.0042, 0.0043, 0.0042], device='cuda:0',
       grad_fn=<MeanBackward2>)

By the way, if I shouldn’t call half(), how can I use torch functions such as torch.randn() inside a module? I got a type error without half() in PyTorch 1.1. Will it be different in 1.4?

Thanks for splitting up the code!
We’ll try to reproduce this issue locally.

I assume you are creating these tensors in the forward method? While O1 and native amp should work out of the box, I’m unsure if O2 would work and we would need to see a code snippet for it.
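One way to avoid explicit `half()` calls when creating tensors inside `forward` is to derive dtype and device from the input. A minimal sketch with a hypothetical `NoisyBlock` module (not taken from the original model); whether this is sufficient under O2 is not confirmed in this thread:

```python
import torch
import torch.nn as nn

class NoisyBlock(nn.Module):
    """Hypothetical module that adds Gaussian noise inside forward."""

    def __init__(self, std=0.1):
        super().__init__()
        self.std = std

    def forward(self, x):
        # randn_like copies dtype and device from x, so no .half() or .cuda() is needed
        noise = torch.randn_like(x) * self.std
        return x + noise

block = NoisyBlock()
for dtype in (torch.float32, torch.float64):
    y = block(torch.zeros(2, 3, dtype=dtype))
    print(dtype, y.dtype)
```

The loop uses float32 and float64 only so that the sketch runs on CPU; the same pattern applies to half-precision inputs on a GPU.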

EDIT: I was able to reproduce and narrow down this issue to a faulty kernel in cudnn.
I’ve also tested the next release candidate and the issue is gone, so my only advice for you is not to use this particular convolution in mixed-precision training. :confused:

Is there any possible workaround? If I cannot use the kernel, then many recent 3D ResNets cannot be used with amp.
And what do you mean by the next release candidate? The next release of cudnn? Do you have a schedule for the release?

As for the example code snippet using torch.randn(), there is an old GitHub issue that I opened a couple of months ago. I saw that you answered it.


Please check if you can reproduce with the modified code.

Yes, the next cudnn release. I don’t know exactly when it’s released.

That sounds bad. Do you have links to public repositories, which implement these models and could post them here, please?

I answered in the issue and cannot reproduce the error with your provided code.
Could you check, if you are still seeing this issue?

So the only way is to wait until the next cudnn is released (probably 7.7 or 8.0?) and PyTorch supports the new release. Maybe I need to find some other model that doesn’t use the (3,1,1) kernel.

There are several models that use Conv3D with a (3, 1, 1) kernel.
The most famous one might be the SlowFast network from Facebook, and there are many follow-up models.

R(2+1)D also uses 3D conv with a (3, 1, 1) kernel.

As for the last issue, it seems to be related to the PyTorch version. It works without the error in PyTorch 1.4.
The problem now is that I cannot use amp in PyTorch 1.4 because of the above issue.

I tested all 8 possible kernel_size values from [1, 1, 1] to [3, 3, 3] and found that only [3, 1, 1] has the issue.
So, as a workaround, I am considering using Conv3d with a [1, 1, 3] kernel and permuting the input/output tensors before/after the convolution inside the forward method, as below.

class Conv3d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
              padding_mode='zeros'):
        super(Conv3d, self).__init__()
        print('Use Conv3d Module')
        self.permute = False
        ks, sr, pad, dl = kernel_size, stride, padding, dilation
        if isinstance(kernel_size, (list, tuple)):
            if kernel_size == (3, 1, 1) or kernel_size == [3, 1, 1]:
                self.permute = True
                ks = ks[::-1]
                if isinstance(sr, (list, tuple)): sr = sr[::-1]
                if isinstance(pad, (list, tuple)): pad = pad[::-1]
                if isinstance(dl, (list, tuple)):  dl = dl[::-1]
        self.conv = nn.Conv3d(in_channels, out_channels, ks, stride=sr, padding=pad, dilation=dl,
                  groups=groups, bias=bias, padding_mode=padding_mode)

    def forward(self, x):
        if self.permute:
            return self.conv(x.permute(0,1,4,3,2)).permute(0,1,4,3,2)
        else:
            return self.conv(x)

With this module, the MSE in the previous test code drops back to a normal level.
@ptrblck, do you think this is a viable workaround? Will the permutation increase the computation time a lot?
If the operation is functionally the same as the original Conv3d, I’m considering testing it in my model, even though the computation time will increase somewhat.
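The functional equivalence can be checked numerically on CPU in float32. The sketch below uses made-up shapes and compares a (3, 1, 1) convolution against a (1, 1, 3) convolution applied to the permuted input. Note that the weight has to be permuted the same way, which would also matter when loading pretrained (3, 1, 1) weights into the wrapper:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 8, 5, 6, 7)   # (N, C, D, H, W)
w = torch.randn(4, 8, 3, 1, 1)   # (out, in, kD, kH, kW)

# Reference: temporal (3, 1, 1) convolution.
ref = F.conv3d(x, w, padding=(1, 0, 0))

# Workaround: swap the D and W axes on input, weight, and output.
w_perm = w.permute(0, 1, 4, 3, 2).contiguous()   # shape (4, 8, 1, 1, 3)
out = F.conv3d(x.permute(0, 1, 4, 3, 2), w_perm, padding=(0, 0, 1))
out = out.permute(0, 1, 4, 3, 2)

print(out.shape == ref.shape, (out - ref).abs().max().item())
```

The two results should agree up to float32 rounding, so any remaining accuracy gap would have to come from elsewhere.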

Your approach should be valid, and you could compare its performance hit with a run without cudnn. If your permutation is still faster than disabling cudnn, it’s a valid workaround.
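A rough way to make that comparison is to time the forward pass with and without cudnn. This is only a sketch: `time_forward` is a hypothetical helper, the layer and input sizes are made up, and on a machine without a GPU both timings exercise the same CPU path:

```python
import time
import torch
import torch.nn as nn

def time_forward(module, x, n_iters=5):
    """Average forward wall time in seconds (hypothetical helper)."""
    sync = torch.cuda.synchronize if x.is_cuda else (lambda: None)
    with torch.no_grad():
        module(x)   # warm-up
        sync()
        start = time.perf_counter()
        for _ in range(n_iters):
            module(x)
        sync()      # wait for queued GPU work before stopping the clock
    return (time.perf_counter() - start) / n_iters

device = "cuda" if torch.cuda.is_available() else "cpu"
conv = nn.Conv3d(16, 16, kernel_size=(3, 1, 1), padding=(1, 0, 0)).to(device)
x = torch.randn(2, 16, 8, 16, 16, device=device)

t_cudnn = time_forward(conv, x)
# Temporarily disable cudnn for this block only.
with torch.backends.cudnn.flags(enabled=False):
    t_no_cudnn = time_forward(conv, x)

print(t_cudnn, t_no_cudnn)
```

For a real comparison, the permuting wrapper would be timed the same way, ideally on the full model with backward passes included.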

I compared several possible combinations of cuda, pytorch and gpus on my 3D ResNet model.

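| CUDA | PyTorch | GPU | Etc | Top1 | Time (s) |
|---|---|---|---|---|---|
| 9.1 | 1.1 | V100 | | 68.7 | 215732 |
| 9.2 | 1.4 | P40 | | 68.1 | 305547 |
| 9.2 | 1.4 | V100 | | 61.7 | 155272 |
| 9.2 | 1.4 | V100 | Custom Conv | 68.2 | 162994 |
| 9.2 | 1.4 | V100 | No cudnn | 62.9 | 160497 |
| 10.2 | 1.5 | V100 | | 63.0 | 162676 |
| 10.2 | 1.5 | V100 | Custom Conv | 67.6 | 186566 |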

With the custom Conv3d above, the top-1 accuracies increase compared to nn.Conv3d.
However, they are still 0.5~1% lower than with PyTorch 1.1. I’m not sure whether this is caused by random initialization or some other issue. The wall time decreased compared to CUDA 9.1 (the V100 machines are not exactly the same, but each has the same number of CPU cores and GPUs).

When cudnn is turned off, the top-1 accuracy is still lower. To turn off cudnn, I set cudnn.benchmark=False. I’m not sure this is the right way.

No, you should use torch.backends.cudnn.enabled = False.
Your code will only disable the benchmark mode, which tries to find the fastest kernel.
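The two flags are independent, as this short sketch shows:

```python
import torch

# benchmark=False only turns off the autotuner; cudnn kernels are still used.
torch.backends.cudnn.benchmark = False

# enabled=False makes PyTorch fall back to its non-cudnn implementations.
torch.backends.cudnn.enabled = False

print(torch.backends.cudnn.benchmark, torch.backends.cudnn.enabled)
```

Both assignments are global for the process, so they belong at the top of the training script, before the model runs.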

The accuracy drop can be explained by the wrong answers of the current cudnn kernels, so stick to the custom conv approach or disable cudnn (you would have to profile it again unfortunately).

Hi, @ptrblck
Would you mind giving some hints as to why this issue happens on V100 but not on P100, etc.?

Thank you.

Different GPU architectures can use different kernel implementations.
Based on

a particular kernel for this architecture was broken in cudnn 7.6.5.32 and fixed in cudnn>=8.0.
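To see which cudnn version a given PyTorch build uses, `torch.backends.cudnn.version()` returns an integer such as 7605 for 7.6.5, or `None` when cudnn is unavailable. A minimal sketch of a check; the helper name is hypothetical, and treating every 7.x build as suspect is an assumption, since only 7.6.5.32 is confirmed here:

```python
import torch

def cudnn_may_be_affected(version):
    """version: the integer from torch.backends.cudnn.version(), e.g. 7605 for 7.6.5."""
    # Assumption: any build older than cudnn 8.0 is treated as potentially affected.
    return version is not None and version < 8000

version = torch.backends.cudnn.version()   # None when cudnn is unavailable
print(version, cudnn_may_be_affected(version))
```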

Thank you very much for your quick reply.

What do you mean by "wrong answers"? :grinning:

I meant numerical mismatches, e.g. a 1.0 is expected at a particular output index, while the kernel returns 1.1.
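To separate such mismatches from ordinary fp16 rounding, it helps to know how much error rounding alone can produce. The sketch below is a rough CPU-only illustration with made-up shapes: it rounds the input and weight to fp16 precision, still accumulates in fp32, and compares against a pure fp32 reference. It does not model fp16 accumulation or the actual cudnn kernels:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 64, 4, 8, 8)
w = torch.randn(16, 64, 3, 1, 1)

# Full-precision reference.
ref = F.conv3d(x, w, padding=(1, 0, 0))
# Round input and weight to fp16 precision, then compute in fp32.
approx = F.conv3d(x.half().float(), w.half().float(), padding=(1, 0, 0))

mean_abs_err = (ref - approx).abs().mean().item()
rel_err = mean_abs_err / ref.abs().mean().item()
print(mean_abs_err, rel_err)
```

Rounding of this kind gives a small relative error, in line with the 0.01 to 0.03 differences reported earlier for the well-behaved kernels. A difference of around 34 is far outside that range.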
