NaN loss after a few seconds of training

I’m running code for graph convolutional networks. With a simple network, amp works well, but when I switch to a more complex one, the loss becomes NaN after a few seconds of training. I didn’t change any other files. How can I fix it?
Below are the details:

PyTorch: 1.6.0
torchvision: 0.7.0
CUDA: 10.2
cuDNN: 7.5
GPU: RTX 2080 Ti
This is the model file.


import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.cuda.amp import autocast


class GraphAttentionLayer_st(nn.Module):

    def __init__(self, in_features, out_features, A_type = 1, window_size=3, dilation=1):
        super(GraphAttentionLayer_st, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.A_type = A_type
        self.window_size = window_size

        self.inter_channels = out_features//4
        self.W1 = nn.Conv2d(in_features, self.inter_channels, kernel_size=1)   # [N, C, T, V] -> [N, C_inter, T, V]
        self.W2 = nn.Conv2d(in_features, self.inter_channels, kernel_size=1)   # [N, C, T, V] -> [N, C_inter, T, V]
        self.transform = nn.Conv2d(in_features, out_features, kernel_size=1)
        self.out_conv = nn.Conv3d(out_features, out_features, kernel_size=(1, self.window_size, dilation))
        self.out_bn = nn.BatchNorm2d(out_features)

    def forward(self, inp):
        N, C, T, V = inp.size()
        x1 = self.W1(inp).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_channels * T)
        x2 = self.W2(inp).view(N, self.inter_channels * T, V)   # [N, C, T, V] -> [N, C_inter*T, V]; V replaces the hard-coded 25 (num_point)
        attention = torch.matmul(x1, x2)/(self.inter_channels*T) # [N, V, V]
        attention = F.softmax(attention, dim=-2) 
        x = self.transform(inp)       # [N, C_out, T, V]
        out = torch.einsum('nctv,nuv->nctu', x, attention)   # aggregate over the graph: [N, C_out, T, V]
        out = out.view(N, self.out_features, -1, self.window_size, V//self.window_size)
        out = self.out_conv(out).squeeze()
        out = self.out_bn(out)
        return out

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class unit_tcn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
                              stride=(stride, 1))

        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()  # not used

    def forward(self, x):
        x = self.bn(self.conv(x))
        return x


class unit_gcn(nn.Module):
    def __init__(self, in_channels, out_channels, nheads=0):
        super(unit_gcn, self).__init__()
        self.nheads=nheads
        if nheads>0:
            nhid=out_channels//nheads
            self.attentions = nn.ModuleList([GraphAttentionLayer_st(in_channels, nhid, A_type=head, window_size=1, dilation=1) for head in range(nheads)])
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
    def forward(self, x):
        N, C, T, V = x.size()
        out=0
        if self.nheads>0:
            out = torch.cat([att(x) for att in self.attentions], dim=1)
        out = self.bn(out)
        return self.relu(out)


class TCN_GCN_unit(nn.Module):
    def __init__(self, in_channels, out_channels, nheads=0, stride=1):
        super(TCN_GCN_unit, self).__init__()
        self.nheads = nheads
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.gcn1 = unit_gcn(in_channels, out_channels, nheads=nheads)
        self.relu = nn.ReLU()
    def forward(self, x):

        x = self.tcn1(self.gcn1(x))
        return self.relu(x)


class Model(nn.Module):
    def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, in_channels=3):
        super(Model, self).__init__()

            
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
        nheads = 4
        self.l1 = TCN_GCN_unit(in_channels, 64, nheads=nheads)
        self.l2 = TCN_GCN_unit(64, 128, nheads=nheads,  stride=2)
        self.l3 = TCN_GCN_unit(128, 256, nheads=nheads, stride=2)
        self.fc = nn.Linear(256, num_class)
    @autocast()
    def forward(self, x):
        N, C, T, V, M = x.size()
        
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)   # (M V C) bn
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        
        # N*M,C,T,V
        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        x = x.mean(3).mean(1)
        return self.fc(x)

I also tried to run this file with APEX, but the loss became NaN soon as well. I observed the loss scale factor dropping to 1e-20, but that didn’t avoid the problem.
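For reference, the APEX usage was roughly the standard dynamic loss scaling setup (simplified below; the opt_level and the optimizer/criterion/loader names are placeholders, not taken from my actual script):

from apex import amp

# assumption: O1 mixed precision with dynamic loss scaling
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for data, target in loader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()   # backward on the scaled loss
    optimizer.step()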

Could you check in the forward method which layer outputs the first Inf or NaN values?
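E.g. by registering a forward hook on every submodule (a minimal sketch; model stands for your Model instance):

import torch

def make_hook(name):
    # forward hook: report modules whose output contains Inf or NaN
    def hook(module, inp, out):
        outputs = out if isinstance(out, (tuple, list)) else (out,)
        for t in outputs:
            if torch.is_tensor(t) and not torch.isfinite(t).all():
                print(f"non-finite output in {name} ({module.__class__.__name__})")
    return hook

for name, module in model.named_modules():
    module.register_forward_hook(make_hook(name))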

I checked the data flow after you replied. The tensor becomes NaN between the second and the third layer, i.e. between the relu in self.l2 and the unit_gcn layer in self.l3. This confuses me because there is no operation between them: the model only assigns the output of self.relu in self.l2 to x, and then x is already NaN. It’s really confusing.

In addition, I noticed that PyTorch reported NaN gradients several iterations (about 40, with batch size 32) before the inputs became NaN. I figure the abnormal loss causes the weights to become NaN, so the main reason may be that amp is not working well.
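(A simple way to check this, with placeholder names, is to inspect the parameter gradients right after the backward call:)

# right after the backward call; with amp the gradients here are still scaled
bad = [n for n, p in model.named_parameters()
       if p.grad is not None and not torch.isfinite(p.grad).all()]
if bad:
    print("non-finite gradients in:", bad)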

If the initial loss is high and the GradScaler is also using a high scaling factor, the gradients might overflow, which is expected. The scaler will then skip this update and reduce the scaling factor.
However, since the updates are skipped, the forward pass should never return invalid values.
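For reference, the usual native amp loop looks like this (a sketch with placeholder model, optimizer, criterion, and loader); logging scaler.get_scale() shows the factor being lowered after skipped steps:

scaler = torch.cuda.amp.GradScaler()

for data, target in loader:
    optimizer.zero_grad()
    output = model(data)              # forward runs in mixed precision via @autocast()
    loss = criterion(output, target)
    scaler.scale(loss).backward()     # backward on the scaled loss
    scaler.step(optimizer)            # skipped if the scaled grads contain Inf/NaN
    scaler.update()                   # lowers the scale after a skipped step
    print(scaler.get_scale())         # watch the loss scale over iterations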

If I understand it correctly, you are seeing the NaN values here:

class TCN_GCN_unit(nn.Module):
    def __init__(self, in_channels, out_channels, nheads=0, stride=1):
        super(TCN_GCN_unit, self).__init__()
        self.nheads = nheads
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.gcn1 = unit_gcn(in_channels, out_channels, nheads=nheads)
        self.relu = nn.ReLU()
    def forward(self, x):

        x = self.tcn1(self.gcn1(x))
        # x is valid
        return self.relu(x) # returns NaNs?
Here is what I actually observe:

class TCN_GCN_unit(nn.Module):
    def __init__(self, in_channels, out_channels, nheads=0, stride=1):
        super(TCN_GCN_unit, self).__init__()
        self.nheads = nheads
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.gcn1 = unit_gcn(in_channels, out_channels, nheads=nheads)
        self.relu = nn.ReLU()
    def forward(self, x):

        x = self.tcn1(self.gcn1(x))
        x = self.relu(x)
        # I didn't observe NaN here.
        return  x
x = self.l2(x)
# Actually I found NaN here.
x = self.l3(x)

I made a small modification to the code: I split the relu into a separate statement, x = self.relu(x), and then return x, so that I can observe the tensor accurately at that point. But it behaves as shown above.

So you don’t see the NaN inside the module (i.e. after the relu), but outside of it for the “same” tensor. Is my understanding correct?

Yes, your understanding is correct.