Loss began to rise abnormally after a brief decrease

I am using PyTorch to train a CNN + RNN model. The input to the CNN is a face image, and the outputs are a depth image of the face (D) and a feature map of the face (T).

During training, the loss began to rise abnormally after a brief decrease. I don't know the reason; I have already ruled out NaN values in the training set and in the loss. Even if I replace the CNN with the simplest possible stack of a few convolutional layers, the same problem occurs.
Any guidance on what I am missing or doing wrong is much appreciated.
Thanks in advance.

For now I train the CNN alone and only compute the loss on the depth image of the face (D).
Below are the input/output shapes and the CNN definition.

Input:
shape of input: (20625, 3, 256, 256)
example of input[0]:
array([[[0.7651416 , 0.75729847, 0.74553376],
        [0.7651416 , 0.75729847, 0.74553376],
        [0.7664488 , 0.75860566, 0.74684095],
        ...,
        ...,
        [0.09063181, 0.09455337, 0.11023965],
        [0.08148148, 0.08540305, 0.10108932],
        [0.10239651, 0.10631809, 0.12200436]]], dtype=float32)

Output:
shape of D: (20625, 1, 32, 32)
example of output[0]:
array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]], dtype=float32)
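
Before training, it can help to double-check that the arrays really are NaN/Inf-free and that the target scale is sane. A minimal sketch (`inputs` and `labels_D` are just my names for the two arrays shown above):

import numpy as np

# inputs:  (20625, 3, 256, 256) float32; labels_D: (20625, 1, 32, 32) float32
print(np.isnan(inputs).any(), np.isinf(inputs).any())      # expect: False False
print(np.isnan(labels_D).any(), np.isinf(labels_D).any())  # expect: False False
print(inputs.min(), inputs.max())                          # roughly [0, 1] for normalized images
print(labels_D.min(), labels_D.max())                      # reveals the scale of the depth targets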

import torch
import torch.nn as nn

class myCNN(nn.Module):

    def __init__(self):
        super(myCNN, self).__init__()

        self.resize_32 = nn.Upsample(size=32, mode='nearest')
        self.resize_64 = nn.Upsample(size=64, mode='nearest')

        # Stem:
        self.cnn0 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn0.weight)  # xavier_normal is deprecated; use the in-place xavier_normal_
        self.bn0 = nn.BatchNorm2d(64)
        self.non_linearity0 = nn.CELU(alpha=1.0, inplace=False)

        self.cnn01 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn01.weight)
        self.bn01 = nn.BatchNorm2d(128)
        self.non_linearity01 = nn.CELU(alpha=1.0, inplace=False)

        # Block (these layers are reused in every block of forward(),
        # so all three blocks share the same weights):
        self.cnn1 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn1.weight)
        self.bn1 = nn.BatchNorm2d(128)
        self.non_linearity1 = nn.CELU(alpha=1.0, inplace=False)

        self.cnn2 = nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn2.weight)
        self.bn2 = nn.BatchNorm2d(196)
        self.non_linearity2 = nn.CELU(alpha=1.0, inplace=False)

        self.cnn3 = nn.Conv2d(in_channels=196, out_channels=128, kernel_size=3, stride=1, padding=1)
        nn.init.xavier_normal_(self.cnn3.weight)
        self.bn3 = nn.BatchNorm2d(128)
        self.non_linearity3 = nn.CELU(alpha=1.0, inplace=False)

        self.pool = nn.MaxPool2d(kernel_size=2)

        # Feature map head:
        self.cnn4 = nn.Conv2d(in_channels=384, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.cnn5 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1)
        self.cnn6 = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, stride=1, padding=1)

        # Depth map head:
        self.cnn7 = nn.Conv2d(in_channels=384, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.cnn8 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.cnn9 = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1)
        
        
    def forward(self, x):

        # Stem
        x = self.cnn0(x)
        x = self.bn0(x)
        x = self.non_linearity0(x)

        # Block 1 (cnn01 -> cnn2 -> cnn3, then pool: 256x256 -> 128x128)
        x = self.cnn01(x)
        x = self.bn01(x)
        x = self.non_linearity01(x)
        x = self.cnn2(x)
        x = self.bn2(x)
        x = self.non_linearity2(x)
        x = self.cnn3(x)
        x = self.bn3(x)
        x = self.non_linearity3(x)
        x = self.pool(x)

        X1 = self.resize_64(x)

        # Block 2 (shared layers cnn1/cnn2/cnn3, then pool: 128x128 -> 64x64)
        x = self.cnn1(x)
        x = self.bn1(x)
        x = self.non_linearity1(x)
        x = self.cnn2(x)
        x = self.bn2(x)
        x = self.non_linearity2(x)
        x = self.cnn3(x)
        x = self.bn3(x)
        x = self.non_linearity3(x)
        x = self.pool(x)

        X2 = x  # already 64x64

        # Block 3 (same shared layers again, then pool: 64x64 -> 32x32)
        x = self.cnn1(x)
        x = self.bn1(x)
        x = self.non_linearity1(x)
        x = self.cnn2(x)
        x = self.bn2(x)
        x = self.non_linearity2(x)
        x = self.cnn3(x)
        x = self.bn3(x)
        x = self.non_linearity3(x)
        x = self.pool(x)

        X3 = self.resize_64(x)

        # Concatenate the three 128-channel maps -> (N, 384, 64, 64)
        X = torch.cat((X1, X2, X3), 1)

        # Feature map head:
        T = self.cnn4(X)
        T = self.cnn5(T)
        T = self.cnn6(T)
        T = self.resize_32(T)

        # Depth map head:
        D = self.cnn7(X)
        D = self.cnn8(D)
        D = self.cnn9(D)
        D = self.resize_32(D)

        return D, T
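
A quick shape sanity check (a minimal sketch; the random tensor is just a stand-in for a real batch):

model = myCNN()
with torch.no_grad():
    D, T = model(torch.randn(2, 3, 256, 256))  # dummy batch of two 256x256 RGB faces
print(D.shape, T.shape)  # expected: torch.Size([2, 1, 32, 32]) for both heads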
The training loop (I skip any batch whose inputs, labels, or model outputs contain NaN):

for epoch in range(n_epoch):
    # loop over the dataset multiple times
    running_loss = 0.0
    total = 0  # number of samples that contributed to running_loss this epoch

    for i, data in enumerate(trainloader_D, 0):

        images, labels_D = data

        input_images = images.cuda()
        input_labels_D = labels_D.cuda()

        # skip batches whose inputs or labels contain NaN:
        if not torch.isnan(input_images).any() and not torch.isnan(input_labels_D).any():

            # training step
            optimizer.zero_grad()

            outputs_D, _ = model(input_images)

            # skip the update if the model output contains NaN:
            if not torch.isnan(outputs_D).any():

                loss = criterion(outputs_D, input_labels_D)

                loss.backward()
                optimizer.step()

                # compute statistics
                running_loss += loss.item()
                total += labels_D.size(0)

                print('epoch:{}, batch:{}, loss: {}, total:{}'.format(epoch + 1, i + 1, running_loss / max(total, 1), total))
    print('Epoch finished')
print('Finished Training')

An excerpt of the training log: the loss hovers around 450-460, two NaN batches are skipped, and then it explodes:
epoch:1, batch:254,image_id:1265-1270,loss: 453.21149039553455,total:1255
epoch:1, batch:255,image_id:1270-1275,loss: 451.8062790204608,total:1260
epoch:1, batch:256,image_id:1275-1280,loss: 461.1012440119807,total:1265
epoch:1, batch:257,image_id:1280-1285,loss: 460.12750893390086,total:1270
epoch:1, batch:258,image_id:1285-1290,loss: 459.3293710312189,total:1275
epoch:1, batch:259,image_id:1290-1295,loss: 457.93095548599956,total:1280
epoch:1, batch:260,image_id:1295-1300,loss: 456.35654145326134,total:1285
epoch:1, batch:261,image_id:1300-1305,loss: 462.5303374105646,total:1290
epoch:1, batch:262,image_id:1305-1310,loss: 461.4984705906577,total:1295
epoch:1, batch:263,image_id:1310-1315,loss: 460.92106004421527,total:1300
epoch:1, batch:264,image_id:1315-1320,loss: 460.67283501022166,total:1305
epoch:1, batch:265,image_id:1320-1325,loss: 460.3284591602005,total:1310
epoch:1, batch:266,image_id:1325-1330,loss: 470.0556516277473,total:1315
epoch:1, batch:267,image_id:1330-1335,loss: 481.83144106286943,total:1320
epoch:1, batch:268,NaN
epoch:1, batch:269,NaN
epoch:1, batch:270,image_id:1345-1350,loss: 513.4538283135756,total:1325
epoch:1, batch:271,image_id:1350-1355,loss: 720.239857154502,total:1330
epoch:1, batch:272,image_id:1355-1360,loss: 1126.8173295996162,total:1335
epoch:1, batch:273,image_id:1360-1365,loss: 6633.459429116036,total:1340
epoch:1, batch:274,image_id:1365-1370,loss: 24338.859208189955,total:1345
epoch:1, batch:275,image_id:1370-1375,loss: 73173.49750741888,total:1350
epoch:1, batch:276,image_id:1375-1380,loss: 409249.8167048085,total:1355
epoch:1, batch:277,image_id:1380-1385,loss: 932172.7982610408,total:1360
epoch:1, batch:278,image_id:1385-1390,loss: 3669391.2832490955,total:1365
epoch:1, batch:279,image_id:1390-1395,loss: 6217432.944259135,total:1370
epoch:1, batch:280,image_id:1395-1400,loss: 20190537.353916377,total:1375
epoch:1, batch:281,image_id:1400-1405,loss: 42477758.88524277,total:1380
epoch:1, batch:282,image_id:1405-1410,loss: 77050900.88493502,total:1385
epoch:1, batch:283,image_id:1410-1415,loss: 141472458.03570864,total:1390
epoch:1, batch:284,image_id:1415-1420,loss: 250212103.33880645,total:1395
epoch:1, batch:285,image_id:1420-1425,loss: 487474627.4725964,total:1400
epoch:1, batch:286,image_id:1425-1430,loss: 737760512.7271423,total:1405
epoch:1, batch:287,image_id:1430-1435,loss: 1172190678.057897,total:1410
epoch:1, batch:288,image_id:1435-1440,loss: 1873252101.0640528,total:1415
epoch:1, batch:289,image_id:1440-1445,loss: 3269063222.263123,total:1420
epoch:1, batch:290,image_id:1445-1450,loss: 5399095726.617288,total:1425
epoch:1, batch:291,image_id:1450-1455,loss: 7465755323.791353,total:1430

Sorry for my poor English.

Most probably your training is diverging. Try a smaller learning rate.
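
A minimal sketch of the usual mitigations, assuming an Adam optimizer (your optimizer and criterion aren't shown above, so the learning rate and clipping threshold here are illustrative guesses, not tuned values):

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # e.g. 10x smaller than before

for epoch in range(n_epoch):
    for images, labels_D in trainloader_D:
        input_images = images.cuda()
        input_labels_D = labels_D.cuda()

        optimizer.zero_grad()
        outputs_D, _ = model(input_images)
        loss = criterion(outputs_D, input_labels_D)
        loss.backward()

        # Clip the gradient norm so a single bad batch cannot blow up the weights:
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

clip_grad_norm_ returns the total gradient norm measured before clipping, so logging it shows exactly which batch starts the divergence. torch.autograd.set_detect_anomaly(True) can likewise help pinpoint the operation that first produces the NaNs.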