I’m running code for graph convolutional networks. With a simple network, AMP works fine, but when I switch to a more complex one, the loss becomes NaN after a few seconds of training. I didn’t change any other files. How can I fix this?
Below are the details:
PyTorch: 1.6.0
torchvision: 0.7.0
CUDA: 10.2
cuDNN: 7.5
GPU: RTX 2080 Ti
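For context, the training step follows the standard `torch.cuda.amp` recipe from the docs. This is just a minimal sketch, not my exact script; `loader`, `model`, `optimizer`, and `criterion` are placeholders for the real ones:

```python
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
for data, label in loader:                 # `loader`, `model`, `optimizer`,
    optimizer.zero_grad()                  # `criterion` stand in for my real ones
    with autocast():                       # forward + loss in mixed precision
        output = model(data)
        loss = criterion(output, label)
    scaler.scale(loss).backward()          # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)                 # unscales grads; skips the step on inf/NaN
    scaler.update()                        # adjusts the loss scale for the next step
```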
This is the model file:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast


class GraphAttentionLayer_st(nn.Module):
    """Attention over the V (joint) dimension, applied to 1x1-conv-transformed features."""

    def __init__(self, in_features, out_features, A_type=1, window_size=3, dilation=1):
        super(GraphAttentionLayer_st, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.A_type = A_type
        self.window_size = window_size
        self.inter_channels = out_features // 4
        self.W1 = nn.Conv2d(in_features, self.inter_channels, kernel_size=1)  # [N, C, T, V] -> [N, C_inter, T, V]
        self.W2 = nn.Conv2d(in_features, self.inter_channels, kernel_size=1)  # [N, C, T, V] -> [N, C_inter, T, V]
        self.transform = nn.Conv2d(in_features, out_features, kernel_size=1)
        self.out_conv = nn.Conv3d(out_features, out_features, kernel_size=(1, self.window_size, dilation))
        self.out_bn = nn.BatchNorm2d(out_features)

    def forward(self, inp):
        N, C, T, V = inp.size()
        x1 = self.W1(inp).permute(0, 3, 1, 2).contiguous().view(N, V, self.inter_channels * T)
        x2 = self.W2(inp).view(N, self.inter_channels * T, V)  # [N, C, T, V] -> [N, C_inter*T, V] (V == 25 here)
        attention = torch.matmul(x1, x2) / (self.inter_channels * T)  # [N, V, V]
        attention = F.softmax(attention, dim=-2)
        x = self.transform(inp)  # [N, C_out, T, V]
        out = torch.einsum('nctv,nuv->nctu', x, attention)
        out = out.view(N, self.out_features, -1, self.window_size, V // self.window_size)
        out = self.out_conv(out).squeeze()
        out = self.out_bn(out)
        return out

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class unit_tcn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1),
                              padding=(pad, 0), stride=(stride, 1))
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()  # not used

    def forward(self, x):
        x = self.bn(self.conv(x))
        return x


class unit_gcn(nn.Module):
    def __init__(self, in_channels, out_channels, nheads=0):
        super(unit_gcn, self).__init__()
        self.nheads = nheads
        if nheads > 0:
            nhid = out_channels // nheads
            self.attentions = nn.ModuleList([
                GraphAttentionLayer_st(in_channels, nhid, A_type=head, window_size=1, dilation=1)
                for head in range(nheads)
            ])
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        N, C, T, V = x.size()
        out = 0
        if self.nheads > 0:
            out = torch.cat([att(x) for att in self.attentions], dim=1)  # concat heads on channels
        out = self.bn(out)
        return self.relu(out)


class TCN_GCN_unit(nn.Module):
    def __init__(self, in_channels, out_channels, nheads=0, stride=1):
        super(TCN_GCN_unit, self).__init__()
        self.nheads = nheads
        self.tcn1 = unit_tcn(out_channels, out_channels, stride=stride)
        self.gcn1 = unit_gcn(in_channels, out_channels, nheads=nheads)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.tcn1(self.gcn1(x))
        return self.relu(x)


class Model(nn.Module):
    def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, in_channels=3):
        super(Model, self).__init__()
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
        nheads = 4
        self.l1 = TCN_GCN_unit(3, 64, nheads=nheads)
        self.l2 = TCN_GCN_unit(64, 128, nheads=nheads, stride=2)
        self.l3 = TCN_GCN_unit(128, 256, nheads=nheads, stride=2)
        self.fc = nn.Linear(256, num_class)

    @autocast()
    def forward(self, x):
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)  # (M V C) bn
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        # N*M, C, T, V
        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        x = x.mean(3).mean(1)  # pool over time/joints, then over persons
        return self.fc(x)
```
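And this is how I call the model; a minimal smoke test appended to the same file (batch size and frame count are arbitrary, the joint and person counts follow the defaults above):

```python
# Quick shape check for the model above.
# Input layout is (N, C, T, V, M): batch, channels, frames, joints, persons.
if __name__ == '__main__':
    model = Model(num_class=60).cuda()
    x = torch.randn(4, 3, 64, 25, 2, device='cuda')  # dummy batch
    with torch.no_grad():
        out = model(x)
    print(out.shape)  # torch.Size([4, 60])
```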