I use PyTorch to train a CNN + RNN model. The input to the CNN is a face image, and its outputs are a depth image of the face (D) and a feature map of the face (T).
During training, the loss began to rise abnormally after a brief decrease. I don't know the reason; I have already filtered out NaN values in the training set and in the loss. Even if I replace the CNN with the simplest possible stack of a few convolutional layers, the same problem occurs.
Any guidance on what I am missing or doing wrong is much appreciated.
Thanks in advance.
First, I train only the CNN and compute the loss only on the depth image of the face (D).
Here is the CNN definition.
Input:
shape of input: (20625, 3, 256, 256)
example of input[0]:
array([[[0.7651416 , 0.75729847, 0.74553376],
[0.7651416 , 0.75729847, 0.74553376],
[0.7664488 , 0.75860566, 0.74684095],
...,
...,
[0.09063181, 0.09455337, 0.11023965],
[0.08148148, 0.08540305, 0.10108932],
[0.10239651, 0.10631809, 0.12200436]]], dtype=float32)
Output:
shape of D: (20625, 1, 32, 32)
example of output[0]:
array([[[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)
class myCNN(nn.Module):
    """CNN mapping a face image to a depth image ``D`` and a feature map ``T``.

    Pipeline, as implemented in ``forward``:

    * a stem ``cnn0 -> bn0 -> CELU``;
    * three conv blocks, each applying ``conv -> bn -> CELU`` three times and
      then a 2x2 max-pool;
    * the three block outputs are each resized to 64x64 (nearest) and
      concatenated along the channel axis (3 * 128 = 384 channels);
    * two convolutional heads, ``cnn4-6`` producing ``T`` and ``cnn7-9``
      producing ``D``, each resized to 32x32 (nearest).

    NOTE(review): blocks 2 and 3 share ``cnn1``/``bn1``, and all three blocks
    share ``cnn2``/``bn2``/``cnn3``/``bn3``. Besides tying the weights, this
    makes each shared ``BatchNorm2d`` accumulate running statistics over three
    different spatial scales. If independent blocks were intended, give each
    block its own layers (note: that changes the ``state_dict`` keys).

    NOTE(review): neither head has a non-linearity between its convolutions,
    so each head is just a composition of affine maps. The final 64 -> 32
    nearest resize keeps only every other pixel.

    NOTE(review): the sample ``input[0]`` printed above prints rows of three
    values, which looks channels-last ``(H, W, 3)`` rather than ``(3, H, W)``.
    Confirm the images were permuted, not reshaped, into ``(N, 3, H, W)``.
    """

    def __init__(self):
        super().__init__()
        self.resize_32 = nn.Upsample(size=32, mode='nearest')
        self.resize_64 = nn.Upsample(size=64, mode='nearest')

        # Stem. ``xavier_normal_`` is the non-deprecated, in-place spelling
        # of ``xavier_normal``.
        self.cnn0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn0.weight)
        self.bn0 = nn.BatchNorm2d(64)
        self.non_linearity0 = nn.CELU(alpha=1.0, inplace=False)

        # First layer of block 1 (64 -> 128 channels).
        self.cnn01 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn01.weight)
        self.bn01 = nn.BatchNorm2d(128)
        self.non_linearity01 = nn.CELU(alpha=1.0, inplace=False)

        # Block layers, reused by the blocks in ``forward`` (see class note).
        self.cnn1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn1.weight)
        self.bn1 = nn.BatchNorm2d(128)
        self.non_linearity1 = nn.CELU(alpha=1.0, inplace=False)
        self.cnn2 = nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn2.weight)
        self.bn2 = nn.BatchNorm2d(196)
        self.non_linearity2 = nn.CELU(alpha=1.0, inplace=False)
        self.cnn3 = nn.Conv2d(in_channels=196, out_channels=128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn3.weight)
        self.bn3 = nn.BatchNorm2d(128)
        self.non_linearity3 = nn.CELU(alpha=1.0, inplace=False)
        self.pool = nn.MaxPool2d(kernel_size=2)

        # Feature-map head (T): 384 -> 128 -> 3 -> 1 channels.
        self.cnn4 = nn.Conv2d(in_channels=384, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.cnn5 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.cnn6 = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1)

        # Depth-map head (D): 384 -> 128 -> 64 -> 1 channels.
        self.cnn7 = nn.Conv2d(in_channels=384, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.cnn8 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.cnn9 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _conv_bn_act(x, conv, bn, act):
        """Apply ``conv``, then ``bn``, then ``act`` to ``x``."""
        return act(bn(conv(x)))

    def _block(self, x, first_conv, first_bn, first_act):
        """Run one conv block: the given first layer, then the shared
        ``cnn2``/``cnn3`` layers, then a 2x2 max-pool."""
        x = self._conv_bn_act(x, first_conv, first_bn, first_act)
        x = self._conv_bn_act(x, self.cnn2, self.bn2, self.non_linearity2)
        x = self._conv_bn_act(x, self.cnn3, self.bn3, self.non_linearity3)
        return self.pool(x)

    def forward(self, x):
        """Compute the depth and feature maps for a batch of images.

        Args:
            x: image batch; presumably ``(N, 3, H, W)`` (``cnn0`` expects 3
                input channels). 256x256 is the size used in the question.

        Returns:
            Tuple ``(D, T)``, each of shape ``(N, 1, 32, 32)``.
        """
        x = self._conv_bn_act(x, self.cnn0, self.bn0, self.non_linearity0)

        x = self._block(x, self.cnn01, self.bn01, self.non_linearity01)
        X1 = self.resize_64(x)
        x = self._block(x, self.cnn1, self.bn1, self.non_linearity1)
        # For 256x256 inputs x is already 64x64 here, so this resize is a
        # no-op. It lets other input sizes reach ``torch.cat`` with matching
        # spatial shapes.
        X2 = self.resize_64(x)
        x = self._block(x, self.cnn1, self.bn1, self.non_linearity1)
        X3 = self.resize_64(x)

        X = torch.cat((X1, X2, X3), 1)

        T = self.resize_32(self.cnn6(self.cnn5(self.cnn4(X))))
        D = self.resize_32(self.cnn9(self.cnn8(self.cnn7(X))))
        return D, T
# Maximum global gradient norm per optimizer step. A loss that grows
# geometrically from batch to batch, as in the log below, is the classic sign
# of exploding updates; clipping bounds each step. Tune this value, or set it
# to None to disable clipping.
# NOTE(review): a per-sample loss in the hundreds suggests the depth labels
# may not be normalised like the [0, 1] inputs; if so, rescale them and/or
# lower the learning rate as well — confirm.
max_grad_norm = 5.0

for epoch in range(n_epoch):
    # One pass over the dataset per epoch.
    model.train()  # BatchNorm must use batch statistics while training
    running_loss = 0.0
    # Reset per epoch alongside running_loss. Previously `total` was never
    # initialised here, so it raised NameError or kept counting samples from
    # earlier epochs, making the printed average wrong.
    total = 0
    for i, (images, labels_D) in enumerate(trainloader_D, 0):
        input_images = images.cuda()
        input_labels_D = labels_D.cuda()

        # Skip batches whose images or labels contain NaN or +/-Inf.
        if not (torch.isfinite(input_images).all() and torch.isfinite(input_labels_D).all()):
            print('epoch:{}, batch:{},NaN'.format(epoch + 1, i + 1))
            continue

        optimizer.zero_grad()
        outputs_D, _ = model(input_images)
        loss = criterion(outputs_D, input_labels_D)

        # A non-finite output or loss means this step would write NaN/Inf
        # into the weights. Checking only the outputs (as before) still let
        # an overflowing loss reach backward() and step().
        if not (torch.isfinite(outputs_D).all() and torch.isfinite(loss).all()):
            print('epoch:{}, batch:{},NaN'.format(epoch + 1, i + 1))
            continue

        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # loss.item() is a per-batch mean under PyTorch's default
        # reduction='mean' (criterion is defined elsewhere — confirm).
        # Weighting by batch size makes running_loss / total a per-sample
        # mean; previously a sum of batch means was divided by the sample
        # count. Drop the multiplication if criterion uses reduction='sum'.
        batch_size = labels_D.size(0)
        running_loss += loss.item() * batch_size
        total += batch_size
        print('epoch:{}, batch:{},loss: {},total:{}'.format(epoch + 1, i + 1, running_loss / max(total, 1), total))
    print('Epoch finished')
print('Finished Training')
epoch:1, batch:254,imgae_id:1265-1270,loss: 453.21149039553455,total:1255
epoch:1, batch:255,imgae_id:1270-1275,loss: 451.8062790204608,total:1260
epoch:1, batch:256,imgae_id:1275-1280,loss: 461.1012440119807,total:1265
epoch:1, batch:257,imgae_id:1280-1285,loss: 460.12750893390086,total:1270
epoch:1, batch:258,imgae_id:1285-1290,loss: 459.3293710312189,total:1275
epoch:1, batch:259,imgae_id:1290-1295,loss: 457.93095548599956,total:1280
epoch:1, batch:260,imgae_id:1295-1300,loss: 456.35654145326134,total:1285
epoch:1, batch:261,imgae_id:1300-1305,loss: 462.5303374105646,total:1290
epoch:1, batch:262,imgae_id:1305-1310,loss: 461.4984705906577,total:1295
epoch:1, batch:263,imgae_id:1310-1315,loss: 460.92106004421527,total:1300
epoch:1, batch:264,imgae_id:1315-1320,loss: 460.67283501022166,total:1305
epoch:1, batch:265,imgae_id:1320-1325,loss: 460.3284591602005,total:1310
epoch:1, batch:266,imgae_id:1325-1330,loss: 470.0556516277473,total:1315
epoch:1, batch:267,imgae_id:1330-1335,loss: 481.83144106286943,total:1320
epoch:1, batch:268,NaN
epoch:1, batch:269,NaN
epoch:1, batch:270,imgae_id:1345-1350,loss: 513.4538283135756,total:1325
epoch:1, batch:271,imgae_id:1350-1355,loss: 720.239857154502,total:1330
epoch:1, batch:272,imgae_id:1355-1360,loss: 1126.8173295996162,total:1335
epoch:1, batch:273,imgae_id:1360-1365,loss: 6633.459429116036,total:1340
epoch:1, batch:274,imgae_id:1365-1370,loss: 24338.859208189955,total:1345
epoch:1, batch:275,imgae_id:1370-1375,loss: 73173.49750741888,total:1350
epoch:1, batch:276,imgae_id:1375-1380,loss: 409249.8167048085,total:1355
epoch:1, batch:277,imgae_id:1380-1385,loss: 932172.7982610408,total:1360
epoch:1, batch:278,imgae_id:1385-1390,loss: 3669391.2832490955,total:1365
epoch:1, batch:279,imgae_id:1390-1395,loss: 6217432.944259135,total:1370
epoch:1, batch:280,imgae_id:1395-1400,loss: 20190537.353916377,total:1375
epoch:1, batch:281,imgae_id:1400-1405,loss: 42477758.88524277,total:1380
epoch:1, batch:282,imgae_id:1405-1410,loss: 77050900.88493502,total:1385
epoch:1, batch:283,imgae_id:1410-1415,loss: 141472458.03570864,total:1390
epoch:1, batch:284,imgae_id:1415-1420,loss: 250212103.33880645,total:1395
epoch:1, batch:285,imgae_id:1420-1425,loss: 487474627.4725964,total:1400
epoch:1, batch:286,imgae_id:1425-1430,loss: 737760512.7271423,total:1405
epoch:1, batch:287,imgae_id:1430-1435,loss: 1172190678.057897,total:1410
epoch:1, batch:288,imgae_id:1435-1440,loss: 1873252101.0640528,total:1415
epoch:1, batch:289,imgae_id:1440-1445,loss: 3269063222.263123,total:1420
epoch:1, batch:290,imgae_id:1445-1450,loss: 5399095726.617288,total:1425
epoch:1, batch:291,imgae_id:1450-1455,loss: 7465755323.791353,total:1430
Sorry for my poor English.