Evaluator returns NaN?

Hello. I ran into a strange situation with ignite’s evaluator.

I have a PyTorch dataloader and a well-trained model.
When I extract the output from the model manually, as below, it works well:

logfile = open('logs/manual.log', 'w')

dataloader_iter = iter(dataloader)

for i in range(500):
    x, t = next(dataloader_iter)
    y = model(x)
    
    logfile.write('{0}th iteration...\n'.format(i))
    logfile.write('    x: {0}\n'.format(x))
    logfile.write('    O: {0}\n'.format(y))
    logfile.write('    o: {0}\n'.format(torch.argmax(y, dim=1)))
    logfile.write('    t: {0}\n\n'.format(t))
    
logfile.close()

Then the logfile shows the following (focus on O, the direct output of the model):

1th iteration...
    <I skipped other things...>
O: tensor([[-2.8684e+00, -2.4779e+00,  1.4409e+00,  4.2450e+00,  1.3716e+00,
          3.1618e+00,  2.5039e+00, -1.3585e-03, -3.0291e+00, -3.1306e+00],
        [ 3.1552e+00,  1.8194e+00, -3.0561e-01, -1.8409e+00, -2.6845e-01,
         -3.2020e+00, -3.0243e+00, -7.1415e-01,  2.4644e+00,  1.6819e+00],
        [ 8.4511e-01, -4.4308e-01,  1.1573e+00,  1.2453e-01,  5.5392e-01,
          6.7758e-03, -1.0871e+00,  2.9498e-01, -7.7099e-01, -2.3612e-01],
        [ 4.2768e+00,  1.0600e+00,  2.5328e-01, -8.9427e-01, -1.2167e+00,
         -2.5544e+00, -4.6379e+00, -9.9491e-01,  4.3370e+00,  3.3613e-01],
        [-1.5431e+00, -2.1083e+00,  2.6824e+00,  1.7168e+00,  2.7112e+00,
          9.7829e-01,  2.7005e+00,  6.7293e-01, -3.5220e+00, -3.2395e+00],
        [-1.7089e+00, -6.2691e-02, -6.5357e-01,  1.6353e+00,  3.9782e-01,
          1.3109e+00,  3.1719e-01,  2.2915e-01, -2.3173e+00,  1.5951e+00],
        [ 3.6005e-01,  6.5493e+00, -1.8465e+00,  1.6237e-01, -2.9798e+00,
         -1.6791e+00, -3.1212e+00, -6.6102e-01, -1.4621e+00,  4.4547e+00],
        [-4.2739e-02, -3.4736e+00,  2.4689e+00,  4.2563e-01,  2.9417e+00,
         -7.3102e-01,  2.5900e+00, -1.5313e-01, -6.9163e-01, -1.9232e+00],
        [ 5.4870e-01, -3.8533e+00,  1.4458e+00,  2.0452e+00,  1.9034e+00,
          2.4421e+00, -1.2948e+00,  6.9678e-01,  3.8147e-01, -3.9030e+00],
        [ 1.9710e+00,  2.5695e+00,  6.1807e-01, -3.4376e-01,  7.6892e-02,
         -2.4107e+00, -8.0755e-01, -6.9068e-01, -6.7928e-01,  3.1438e-01],
        [-2.5117e-01, -4.4562e+00,  2.7807e+00,  2.3242e+00,  3.1733e+00,
          2.7852e+00,  5.5874e-01, -1.3494e-01,  1.5951e-01, -5.1482e+00],
        [-7.8764e-01,  1.3296e+00, -1.2178e+00,  5.5745e-01, -1.5662e+00,
         -5.0908e-01, -1.1492e+00,  4.3546e-01, -1.2073e+00,  3.8453e+00],
        [-8.3729e-01, -2.7771e+00,  1.8575e+00,  2.3420e+00,  1.8659e+00,
          2.1749e+00,  1.6091e+00, -4.0015e-01, -1.2970e+00, -3.6298e+00],
        [-1.2611e+00, -2.7069e-02,  9.4248e-02,  8.1586e-01,  1.2401e+00,
          1.4915e+00, -1.5474e+00,  2.9358e+00, -2.9514e+00, -7.6922e-01],
        [ 2.9973e+00,  1.5974e+00,  3.1030e-01, -5.9278e-01, -1.3065e+00,
         -3.6284e+00, -1.7095e+00, -2.5131e+00,  2.5173e+00,  2.9853e+00],
        [ 2.8719e+00,  2.5325e-01,  1.6667e+00, -3.4530e-01,  1.8475e-01,
         -2.7671e+00,  1.9819e+00, -3.5765e+00,  2.2073e+00, -1.5772e+00],
        [-1.3318e+00, -1.1345e+00,  9.9565e-01,  3.0902e+00, -7.6009e-02,
          3.6095e+00, -1.6780e+00,  9.3089e-01, -1.7266e+00, -2.3228e+00],
        [-3.5282e-01, -2.2260e-01, -5.0407e-01, -8.8783e-02,  1.5501e+00,
         -6.1404e-01, -1.8794e+00,  2.5478e+00, -2.1103e+00,  2.4213e+00],
        [ 3.3746e+00,  1.5732e+00, -9.2370e-01, -7.0857e-01, -1.9817e+00,
         -2.9782e+00, -3.5029e+00, -2.0559e+00,  3.9011e+00,  3.1880e+00],
        [-2.7136e+00, -2.1364e+00,  2.3538e+00,  2.2541e+00,  3.3178e+00,
          2.8204e+00,  2.5427e+00,  1.8494e+00, -4.9199e+00, -4.0796e+00]],
       device='cuda:0', grad_fn=<GatherBackward>)

But when I use the evaluator, as below, I see something strange:

logfile2 = open('logs/evaluator.log', 'w')

evaluator = create_evaluator(model)

@evaluator.on(Events.ITERATION_COMPLETED)
def log_inference_output(engine):
    y_pred, y = engine.state.output
    
    logfile2.write('{0}th iteration...\n'.format(engine.state.iteration))
    logfile2.write('    o: {0}\n'.format(y_pred))
    logfile2.write('    t: {0}\n\n'.format(y))
    
evaluator.run(dataloader)

logfile2.close()

The evaluator log shows:

1th iteration...
    o: tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]], device='cuda:0')
    t: tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6],
       device='cuda:0')

Why does this happen? I have been trying to figure it out for several days, but I can’t find the answer.

My environment is:

  • pytorch: 1.7.0
  • ignite: 0.4.2

and I’m using 8 GPUs with a DataParallel model.

Any help will be appreciated. Thanks.

@FruitVinegar interesting situation.

What the evaluator does behind the scenes is (link):

model.eval()
with torch.no_grad():
    x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
    y_pred = model(x)
    # ...

Can you check whether you also get NaNs with the following code:

logfile = open('logs/manual.log', 'w')

dataloader_iter = iter(dataloader)

model.eval()
with torch.no_grad():
    for i in range(500):
        x, t = next(dataloader_iter)
        y = model(x)
    
        logfile.write('{0}th iteration...\n'.format(i))
        logfile.write('    x: {0}\n'.format(x))
        logfile.write('    O: {0}\n'.format(y))
        logfile.write('    o: {0}\n'.format(torch.argmax(y, dim=1)))
        logfile.write('    t: {0}\n\n'.format(t))
    
logfile.close()

Can you describe how your model is defined so that I can reproduce the issue on my side?

Oh my. Why didn’t I put that context manager in my first code snippet?
Well, wrapping my first code with

with torch.no_grad():

(the same as your suggestion) also generates NaN values.

So the problem is (probably) that my model does not work well in the evaluation phase, which means the problem is in my model.
This is an important clue, so I’ll keep trying to figure out what the problem is. Thanks for the good pointer.
(My previous question also turned out not to be an ignite-related problem (lambdas with output_transforms), but I saw traces of it in your API documentation. I hope my work helps other users :wink: )
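One way to narrow down where the NaNs first appear is to attach forward hooks to every submodule and record the first one whose output contains NaN. The sketch below is only an illustration: `first_nan_module` and the toy model are hypothetical names, not part of ignite or Brevitas, and the hook only inspects plain tensor outputs, so modules that return other types (for example quantized tensor wrappers) would need extra handling.

```python
import torch
import torch.nn as nn


def first_nan_module(model, x):
    """Return the name of the first submodule whose output contains NaN, or None."""
    found = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Only plain tensors are checked; other output types are skipped.
            if not found and isinstance(output, torch.Tensor) and torch.isnan(output).any():
                found.append(name)
        return hook

    # Hook every named submodule, skipping the root module (empty name).
    handles = [m.register_forward_hook(make_hook(n))
               for n, m in model.named_modules() if n]
    was_training = model.training
    model.eval()  # reproduce the evaluator's setting
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()
        model.train(was_training)  # restore the previous mode
    return found[0] if found else None


# Toy stand-in (assumption): a NaN is injected into the last layer's weight.
toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    toy[2].weight[0, 0] = float("nan")

culprit = first_nan_module(toy, torch.randn(3, 4))
print(culprit)
```

On the toy model this reports the last `Linear` layer (named "2" inside the `Sequential`). With a DataParallel-wrapped model, the reported names would carry a `module.` prefix.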

Anyway, my model is a customized ResNet architecture that enables quantization-aware training of both activations and weights with 8 bits.
It uses a third-party library and components, but I think (I hope) the problem is not in them.
This is the third-party library I used: Brevitas: PyTorch library for quantization-aware training.
I based the architecture on the original ResNet from: Source code for torchvision.models.resnet

import typing as ty
#import logging

import torch
import torch.nn as nn
import brevitas
import brevitas.nn as qnn
import torchvision

from architectures.components.PACT import PACTReLU

#logging.basicConfig(filename='QAT_8b_PACT_gradClip.log', level=logging.DEBUG)

__all__ = ['ResNet_QAT_8b', 'resnet18_QAT_8b']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return qnn.QuantConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                           padding=dilation, groups=groups, bias=False, dilation=dilation, weight_bit_width=8)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return qnn.QuantConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False, weight_bit_width=8)


def make_PACT_relu():
    relu = qnn.QuantReLU(bit_width=8)
    relu.act_impl = PACTReLU()
    return relu



class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self,
                 inplanes,
                 planes, 
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = make_PACT_relu()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride
        
        
    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
    
    
        
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4
    
    def __init__(self,
                 inplanes,
                 planes, 
                 stride=1,
                 downsample=None,
                 groups=1,
                 base_width=64,
                 dilation=1,
                 norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = make_PACT_relu()
        self.downsample = downsample
        self.stride = stride
        
        
    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
    
    
    
class ResNet_QAT_8b(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 zero_init_residual=False,
                 groups=1, 
                 width_per_group=64,
                 replace_stride_with_dilation=None,
                 norm_layer=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = qnn.QuantConv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = make_PACT_relu()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d): # qnn.QuantConv2d includes nn.Conv2d inside.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
                    
                    
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)
    
    
    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    
    def forward(self, x):
        return self._forward_impl(x)

    
    
def _resnet_QAT_8b(block,
                   layers, 
                   **kwargs):
    model = ResNet_QAT_8b(block, layers, **kwargs)
    print(model)
    return model


def resnet18_QAT_8b(config, 
                    **kwargs):
    # config is an unused placeholder argument.
    return _resnet_QAT_8b(BasicBlock,
                          [2, 2, 2, 2],
                          **kwargs)

And the definition of the PACT components:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np



class PACTClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x, alpha)
        return torch.clamp(x, 0, alpha.data)
    
    @staticmethod
    def backward(ctx, dy):
        x, alpha = ctx.saved_tensors
        
        dx = dy.clone()
        dx[x < 0] = 0
        dx[x > alpha] = 0
        
        dalpha = dy.clone()
        dalpha[x <= alpha] = 0
        
        return dx, torch.sum(dalpha)
    
    

class PACTReLU(nn.Module):
    def __init__(self, alpha=6.0):
        super().__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha))
        
    def forward(self, x):
        return PACTClip.apply(x, self.alpha)

I’ll post if I find the solution. Thank you.


I’ve just tested your code by training the resnet18 QAT 8-bit model on CIFAR10 with DP on 2 GPUs, and I also get NaNs in the loss and y_pred. I then tested on a single GPU without DP, and the model appears to be learning something.

Maybe you can try wrapping the model with torch DDP so that you can still use all 8 GPUs:

    import ignite.distributed as idist
    # ...

    model = resnet18_QAT_8b(...)
    lrank = idist.get_local_rank()
    model = model.to(idist.device())
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank,], find_unused_parameters=True)

You will also have to use distributed data sampling.
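As a rough illustration of what distributed data sampling means: each process should iterate over its own disjoint shard of the dataset. A minimal sketch with `torch.utils.data.distributed.DistributedSampler` follows. The helper name `make_sharded_loader`, the toy dataset, and the explicit `world_size` / `rank` arguments are assumptions for demonstration; in a real DDP run these values come from the initialized process group (ignite’s `idist.auto_dataloader` can set this up for you).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_sharded_loader(dataset, world_size, rank, batch_size=4):
    # Passing num_replicas and rank explicitly avoids needing an initialized
    # process group, which keeps this sketch runnable in a single process.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=False)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# Toy dataset (assumption): 10 samples whose target equals the sample index.
dataset = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# Simulate what two processes (rank 0 and rank 1) would each see.
shards = []
for rank in range(2):
    loader = make_sharded_loader(dataset, world_size=2, rank=rank)
    shards.append([int(t) for _, targets in loader for t in targets])

print(shards)
```

Each rank sees a different half of the data, and together the shards cover the whole dataset. With `shuffle=True`, call `sampler.set_epoch(epoch)` at the start of every epoch so that the shuffling differs between epochs.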

In my case with CIFAR10, the resnet18 QAT 8-bit model trained with nccl DDP does not produce NaNs and achieves a test accuracy of 0.8453 after 24 epochs (about the same as the same model on a single GPU, test accuracy = 0.846).

Link to the code with DDP, CIFAR10, ResNet QAT: ignite/examples/contrib/cifar10_qat at cifar10-qat · vfdev-5/ignite · GitHub

Wow. Thank you for the thorough help.

Now I can pinpoint where the problem is, but I still don’t understand why it happens.

I ran two tests. First:

dataloader_iter = iter(dataloader)

def forward_and_log1(model, x, t, no_grad=False):
    if no_grad:
        logfile = open('logs/no_grad_enabled.log', 'a')
    else:
        logfile = open('logs/no_grad_disabled.log', 'a')

    
    y = model(x)
    
    logfile.write('    x: {0}\n'.format(x))
    logfile.write('    O: {0}\n'.format(y))
    logfile.write('    o: {0}\n'.format(torch.argmax(y, dim=1)))
    logfile.write('    t: {0}\n\n'.format(t))
    
    logfile.close()

for i in range(100):
    x, t = next(dataloader_iter)
    
    forward_and_log1(model, x, t, no_grad=False)
    with torch.no_grad():
        forward_and_log1(model, x, t, no_grad=True)

There were no problems (NaN values) in either log file (no_grad_enabled.log, no_grad_disabled.log).
So the problem is not with torch.no_grad().

Second:

dataloader_iter = iter(dataloader)

def forward_and_log2(model, x, t, is_train_mode=True):
    if is_train_mode:
        logfile = open('logs/mode_train.log', 'a')
    else:
        logfile = open('logs/mode_eval.log', 'a')
        
    y = model(x)
    
    logfile.write('    x: {0}\n'.format(x))
    logfile.write('    O: {0}\n'.format(y))
    logfile.write('    o: {0}\n'.format(torch.argmax(y, dim=1)))
    logfile.write('    t: {0}\n\n'.format(t))
    
    logfile.close()

    
for i in range(100):
    x, t = next(dataloader_iter)
    
    model_GPU.train()
    forward_and_log2(model, x, t, is_train_mode=True)
    model_GPU.eval()
    forward_and_log2(model, x, t, is_train_mode=False)

Here I found that NaNs occurred in the eval mode logfile.
The problem was eval mode!
But I’m still confused about why this happens, especially given your result (NaNs with DP but not with DDP). Maybe I found a bug in DP?
I found a topic similar to my situation, but its solution was not clear.

Anyway, DDP solves it for me. Thank you.
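Since the NaNs show up only after `model.eval()`, one thing worth checking is state that is used only in eval mode. BatchNorm layers, for instance, normalize with the current batch statistics in train mode but with their stored `running_mean` / `running_var` buffers in eval mode, so a non-finite running statistic stays invisible in train mode. Whether this is what actually happens under DP here is not confirmed in this thread. The sketch below just scans a model’s parameters and buffers for NaN/Inf and shows the train/eval difference on a toy model; `non_finite_state` and the toy network are hypothetical names.

```python
import torch
import torch.nn as nn


def non_finite_state(model):
    """List names of floating-point parameters/buffers containing NaN or Inf."""
    named = list(model.named_parameters()) + list(model.named_buffers())
    return [name for name, t in named
            if t.is_floating_point() and not torch.isfinite(t).all()]


# Toy stand-in (assumption): corrupt one BatchNorm running statistic.
toy = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.BatchNorm2d(4))
with torch.no_grad():
    toy[1].running_var[0] = float("nan")

x = torch.randn(8, 3, 8, 8)

toy.train()  # BatchNorm normalizes with batch statistics
train_has_nan = torch.isnan(toy(x)).any().item()

toy.eval()   # BatchNorm normalizes with running statistics
with torch.no_grad():
    eval_has_nan = torch.isnan(toy(x)).any().item()

bad = non_finite_state(toy)
print(bad, train_has_nan, eval_has_nan)
```

On the toy model the train-mode output is finite while the eval-mode output contains NaNs, and the scan points at the BatchNorm `running_var` buffer. Running `non_finite_state(model)` on the real model right after training with DP would show whether any stored state is already corrupted.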

Yes, I think it is (probably) a known issue with DP and quantization.

Hey @FruitVinegar, I think this time your question will also help other users:
I plan to create a QAT example that uses the code you shared here to train a resnet18 on CIFAR10.
If you’d like to take a look and comment, it would be helpful: Added Cifar10 QAT example by vfdev-5 · Pull Request #1556 · pytorch/ignite · GitHub
Thanks!

Sure. You can adapt my code to other networks or datasets.