3D CNN overfitting issue

Hi,
I am trying to retrain a 3D CNN model from a research article, and I keep running into overfitting even after implementing on-the-fly data augmentation to avoid it.

I can see that my model learns for a while, and then the loss starts to oscillate around the same values.

Any suggestions on how to improve the model or prevent it from overfitting would be of great help. For reference, I am posting my code and the loss below. Thanks!

What I tried to avoid overfitting:

  • Image augmentation
  • Regularization
  • Dropout layers
  • Increasing the batch size
  • Shuffling the data loader

Training dataset size: approx. 358 images + augmentation (every epoch)

class T_LEAP(torch.nn.Module):
    """T_LEAP ARCHITECTURE"""

    def __init__(self):
        super(T_LEAP, self).__init__()
        self.encoder = torch.nn.Sequential(
                                          torch.nn.Conv3d(3, 64, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=64),
                                          torch.nn.ReLU(inplace=True),
                                          torch.nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2)),

                                          torch.nn.Conv3d(64, 128, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=128),
                                          torch.nn.ReLU(inplace=True),
                                          torch.nn.MaxPool3d(kernel_size=(2,2,2), stride=(2,2,2)),

                                          torch.nn.Conv3d(128, 256, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=256),
                                          torch.nn.ReLU(inplace=True),
                                          torch.nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2)),

                                          torch.nn.Conv3d(256, 512, kernel_size=(1,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=512),
                                          torch.nn.ReLU(inplace=True),
                                          )
        
        self.decoder = torch.nn.Sequential(
                                          torch.nn.ConvTranspose3d(512, 256, kernel_size=(1,3,3), stride = (1,2,2), padding=(1,1,1), output_padding=(0,1,1)),
                                          torch.nn.BatchNorm3d(num_features=256), 
                                          torch.nn.ReLU(inplace=True),
                                     
                                          torch.nn.Conv3d(256, 256, kernel_size=(1,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=256),
                                          torch.nn.ReLU(inplace=True),
        
                                          torch.nn.ConvTranspose3d(256,128, kernel_size=(1,3,3), stride = (1,2,2), padding=(1,1,1),output_padding=(0,1,1)), 
                                          torch.nn.BatchNorm3d(num_features=128),
                                          torch.nn.ReLU(inplace=True),
                                          
                                          torch.nn.Conv3d(128, 128, kernel_size=(1,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=128),
                                          torch.nn.ReLU(inplace=True),
                                          
                                          torch.nn.ConvTranspose3d(128,17, kernel_size=(1,3,3), stride = (1,2,2), padding=(1,1,1),output_padding=(0,1,1)),
                                          torch.nn.BatchNorm3d(num_features=17),
                                          torch.nn.ReLU(inplace=True),
                                          )           
               
    def forward(self, image):
        """PUTTING THE MODEL TOGETHER"""
        encoder = self.encoder(image)
        decoder = self.decoder(encoder)
        out     = torch.nn.Softmax(dim=1)(decoder)
        return out


model = T_LEAP()
if torch.cuda.is_available():
  input = torch.rand(1,3, 2, 200, 200).cuda()
  model = model.cuda()
  print(model(input).shape)
else:
  # assumes `summary` comes from the torchsummary package; pass device="cpu" if needed
  from torchsummary import summary
  summary(model,
          input_size=(3, 2, 200, 200),
          batch_size=1
          )

Training loss:
[training-loss plot]

Other things I checked:

Training loop:

DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS     = 35
lr         = 1e-3
model      = T_LEAP()
optimizer  = torch.optim.Adam(model.parameters(), lr=lr)
model = model.to(DEVICE)
criterion  = torch.nn.MSELoss()

train_loss = []
for epoch in range(EPOCHS): 
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # Get the [inputs, labels] from the batch
        inputs, labels = data["images"], data["heatmaps"]

        inputs = inputs.permute(0,2,1,3,4).to(DEVICE,dtype=torch.float)
        labels = labels.unsqueeze(0).permute(1,4,0,2,3).to(DEVICE,dtype=torch.float)
        
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print("loss:",running_loss/len(train_loader))
    train_loss.append(running_loss)
        
print('Finished Training')

How do you estimate overfitting? I don’t see any validation loss here.
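
For reference, a minimal sketch of how a validation loss could be tracked each epoch, assuming a val_loader built from a held-out split the same way as train_loader (the name is hypothetical):

# Hypothetical validation pass; assumes val_loader yields the same
# {"images", "heatmaps"} dict structure as train_loader.
model.eval()
val_loss = 0.0
with torch.no_grad():
    for data in val_loader:
        inputs, labels = data["images"], data["heatmaps"]
        inputs = inputs.permute(0, 2, 1, 3, 4).to(DEVICE, dtype=torch.float)
        labels = labels.unsqueeze(0).permute(1, 4, 0, 2, 3).to(DEVICE, dtype=torch.float)
        val_loss += criterion(model(inputs), labels).item()
print("val loss:", val_loss / len(val_loader))
model.train()  # switch back before the next training epoch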

This was one of the runs that I made on the data, and the model performs like this:

It doesn’t look like overfitting, since the model is unable to generalize at all according to the val loss.
Your weights are rather small; usually this happens when the model decides that the best choice is to produce a constant output (usually zero). I’d look into scaling/normalizing the data, simplifying the model, and decreasing the Adam lr further.
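
A rough sketch of that suggestion, assuming the frames are already tensors in [0, 1] and using made-up per-channel statistics (compute the real ones from your training set); the lower learning rate is likewise just an example:

import torch

# Hypothetical per-channel mean/std; replace with values computed from your own frames.
mean = torch.tensor([0.45, 0.45, 0.45]).view(1, 3, 1, 1, 1)
std  = torch.tensor([0.22, 0.22, 0.22]).view(1, 3, 1, 1, 1)

def normalize_clip(clip):
    """Channel-wise normalization of a (N, C, D, H, W) clip."""
    return (clip - mean.to(clip.device)) / std.to(clip.device)

# A smaller Adam learning rate, as suggested above.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)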


I have scaled the data using ToTensor() and have varied the learning rate of the optimizer from 1e-3 to 1e-5; the training loss decreases at 1e-5, but the model doesn’t learn much compared to a learning rate of 1e-3.

I also noticed that when I print loss.backward() while training, I get None as the output. Would this mean that my model is not learning at all?

To me it seems strange that you are using softmax + MSE loss.
You are computing some sort of probabilities, but then you use a Cartesian distance? Strange.

Are you using the same dataset or a different one? You have convolutions with 512 filters, which leads to a huge number of parameters for a 3D network. Be sure you have an appropriate number of parameters given the size of your dataset, as 3D convolutions tend to overfit.

Although the scale of the plot isn’t ideal, it seems even the training is not learning much if it can converge in 10 epochs.
It would also be nice to use some metric other than the loss to measure “overfitting”. Some sort of accuracy would be nice, since you can check whether the model does better than chance.

Lastly, just a random suggestion: are you using a validation set extracted from the same type of data? If you use 2 different datasets (rather than splitting one into train and val) you may find this.


Hey Juan,
Thanks for getting back and sharing your suggestions.

  • To me it seems strange that you are using softmax + MSE loss.
    You are computing some sort of probabilities, but then you use a Cartesian distance? - The reason for using softmax plus MSE loss is that it is a direct implementation of the paper ([2104.08029] T-LEAP: occlusion-robust pose estimation of walking cows using temporal information). It is a pose estimation problem, so I want confidence maps as the output.

  • Lastly, just a random suggestion: are you using a validation set extracted from the same type of data? If you use 2 different datasets (rather than splitting one into train and val) you may find this. - I am using a personal dataset and splitting the same data into an 80:20 train and validation split. I am using different videos as one dataset.

  • Although the scale of the plot isn’t ideal, it seems even the training is not learning much if it can converge in 10 epochs. - I have trained for 80 epochs and the loss remains on the same scale. However, I realized that when I print loss.backward() I get None as the output, which would also mean that the network is not training.

Well, I’m not sure what the return value of loss.backward() is.
It can be None simply because the backward method doesn’t return anything. The requirement is that the loss is not NaN.
So if you are using a custom dataset, how big is it compared to the original one? You may want to reduce the number of convolutions / filters you are using.

If your loss or any of the inner computations were NaN or Inf, the values of the weights would also be NaN. Basically, when a NaN appears somewhere in the process, it spreads. I think it’s just regular overfitting.

Of course, I’m not sure either about the range of values your loss should be reaching, but 10e-3 already seems small. I would bet it reached those values within the first epoch.

  • So if you are using a custom dataset, how big is it compared to the original one? - The paper says they used samples of four consecutive frames, resulting in 1059 samples (assuming 1059 × 4 frames). In my case, I am using 387 images + data augmentation per epoch. With data augmentation, I once generated 1059 frames and it still got stuck. Like you said, the network might need more data or need to be reduced in size.
  • Of course, I’m not sure either about the range of values your loss should be reaching, but 10e-3 already seems small. I would bet it reached those values within the first epoch. - For 1e-3 it reaches those values in the first epoch, and then the loss oscillates around the same values for a long time and decreases very slowly.
  • One of my last-resort ideas was to eventually increase the training data to see if it helps in any way. However, using augmentation does the trick.

Augmentation helps, but it’s not magic.
I would count the parameters and try a smaller network. But still, try to measure the performance with any metric other than the loss. All networks (especially 3D CNN ones) may overfit.

I’m more curious about your gradient sizes and model length. Maybe your model needs some split connections? Print a sum of your gradients at each layer.

Could you suggest any other metric? Thanks

Model summary

I tried to do loss.register_hook(lambda grad: print(grad)).
How do you compute the gradients at each layer?

[image attachment]

Well, given that you are estimating joints, you just need to have a look at other papers like OpenPose.

So you can basically compute accuracy (whether the predicted joints match the proper ones), the joint-wise error, recall, etc.
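
For example, a minimal sketch of a PCK-style accuracy computed directly from the heatmaps (argmax per joint channel, then a pixel-distance threshold); the function names, the threshold, and the assumption that the heatmaps have been reduced to (N, 17, H, W) are all illustrative, not from the paper:

import torch

def heatmap_to_coords(heatmaps):
    """Convert (N, J, H, W) heatmaps to (N, J, 2) peak (x, y) coordinates."""
    n, j, h, w = heatmaps.shape
    flat = heatmaps.view(n, j, -1).argmax(dim=-1)   # index of the hottest pixel per joint
    return torch.stack([flat % w, flat // w], dim=-1).float()

def pck(pred_heatmaps, gt_heatmaps, threshold=5.0):
    """Fraction of joints whose predicted peak lies within `threshold` pixels of the target peak."""
    dist = torch.linalg.norm(heatmap_to_coords(pred_heatmaps) - heatmap_to_coords(gt_heatmaps), dim=-1)
    return (dist < threshold).float().mean().item()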


Hey Johnson, by split connections do you mean skip connections? I tried to add skip connections and test them.

My output for the architecture below looks like this:
[output screenshot]

class T_LEAP_SKIP(torch.nn.Module):
    """T_LEAP ARCHITECTURE"""

    def __init__(self):
        super(T_LEAP_SKIP, self).__init__()
        
        self.upconv1 = torch.nn.Sequential(torch.nn.Conv3d(3, 64, kernel_size=(3,3,3), 
                                                stride=(1,1,1), padding=(1,1,1)),
                                                torch.nn.BatchNorm3d(num_features=64),
                                                torch.nn.ReLU(inplace=True),
                                                )
        self.pool1   = torch.nn.MaxPool3d(kernel_size=(1,2,2),stride=(1,2,2))

        self.upconv2 = torch.nn.Sequential(torch.nn.Conv3d(64, 128,kernel_size=(3,3,3), 
                                                    stride=(1,1,1), padding=(1,1,1)), 
                                                                            
                                      torch.nn.BatchNorm3d(num_features=128),
                                      torch.nn.ReLU(inplace=True),
                                      
                                    )
        self.pool2   = torch.nn.MaxPool3d(kernel_size=(2,2,2),stride=(2,2,2))

        self.upconv3 = torch.nn.Sequential(torch.nn.Conv3d(128, 256,
                                                           kernel_size=(3,3,3), 
                                                           stride=(1,1,1), 
                                                           padding=(1,1,1)), 
                                                                            
                                      torch.nn.BatchNorm3d(num_features=256),
                                      torch.nn.ReLU(inplace=True),
                                      
                                    )
        
        self.pool3   = torch.nn.MaxPool3d(kernel_size=(1,2,2),stride=(1,2,2))

        self.upconv4 = torch.nn.Sequential(torch.nn.Conv3d(256, 
                                                           512,
                                                           kernel_size=(1,3,3), 
                                                           stride=(1,1,1),
                                                           padding=(1,1,1)), 
                                                                            
                                      torch.nn.BatchNorm3d(num_features=512),
                                      torch.nn.ReLU(inplace=True),
                               
                                    )
        #-----------------------------------------------------------------------
        #1st conv-transpose layer 
        self.convt1 = torch.nn.Sequential(torch.nn.ConvTranspose3d(512, 
                                                                   256, 
                                                                   kernel_size=(1,3,3),
                                                                   stride = (1,2,2), 
                                                                   padding=(1,1,1),
                                                                   output_padding=(0,1,1)),
                                          torch.nn.BatchNorm3d(num_features=256), 
                                          torch.nn.ReLU(inplace=True)
                                          )
        self.down1  = torch.nn.Sequential(torch.nn.Conv3d(512, 256, kernel_size=(1,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=256),
                                          torch.nn.ReLU(inplace=True))
        
        
        #2nd conv-transpose layer 
        self.convt2 = torch.nn.Sequential(torch.nn.ConvTranspose3d(256, 
                                                                   128, 
                                                                   kernel_size=(1,3,3),
                                                                   stride = (1,2,2), 
                                                                   padding=(1,1,1),
                                                                   output_padding=(0,1,1)),
                                          torch.nn.BatchNorm3d(num_features=128), 
                                          torch.nn.ReLU(inplace=True)
                                          )
        
        self.down2  = torch.nn.Sequential(torch.nn.Conv3d(256,128, kernel_size=(1,3,3), stride=(1,1,1), padding=(1,1,1)),
                                          torch.nn.BatchNorm3d(num_features=128),
                                          torch.nn.ReLU(inplace=True)) 

        #Final conv-transpose layer                                   
        self.convt3 = torch.nn.Sequential(torch.nn.ConvTranspose3d(128, 
                                                                   17, 
                                                                   kernel_size=(1,3,3),
                                                                   stride = (1,2,2), 
                                                                   padding=(1,1,1),
                                                                   output_padding=(0,1,1)),
                                          torch.nn.BatchNorm3d(num_features=17), 
                                          torch.nn.ReLU(inplace=True)
                                          )

    def forward(self, image):
        """PUTTING THE MODEL TOGETHER"""
        
        # Encoder 
        block1 = self.upconv1(image)
        pool1  = self.pool1(block1)

        block2 = self.upconv2(pool1)
        pool2  = self.pool2(block2)

        block3 = self.upconv3(pool2)
        pool3  = self.pool3(block3)
      
        block4 = self.upconv4(pool3)


        #Decoder
        trans1 = self.convt1(block4)
        skip1   = torch.cat([trans1, block3], axis=1)
        up1     = self.down1(skip1)

        trans2 = self.convt2(up1)
        skip2  = torch.cat([trans2, block2[:,:,1:,:]], axis=1)   # crop block2 along depth to match trans2 (trouble with matching depth)
        up2    = self.down2(skip2)

        trans3 = self.convt3(up2)
        output = torch.nn.Softmax(dim=1)(trans3)
        return output

model = T_LEAP_SKIP()
input = torch.rand(1, 3, 2, 200, 200)
print("Model output", model(input).shape)

# summary(model,
#         input_size=(3, 2, 200, 200),
#         batch_size=1
#         )

You can get the gradients from a tensor by

tensor.grad
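
For example, to get a quick per-layer gradient summary right after loss.backward() (note that loss.backward() itself always returns None; it only fills each parameter's .grad):

# Run right after loss.backward() to inspect gradient magnitudes layer by layer.
for name, param in model.named_parameters():
    if param.grad is None:
        print(f"{name}: no gradient")
    else:
        print(f"{name}: grad sum = {param.grad.sum().item():.3e}, "
              f"abs mean = {param.grad.abs().mean().item():.3e}")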

For the skip connections, I’d try adding a Linear or Conv3d in between, keeping the size identical with padding, plus a ReLU activation and batchnorm. Highway Networks are an option, too.
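
A rough sketch of what such a size-preserving block with a skip connection could look like (an additive residual variant; just an illustration, not the T-LEAP architecture):

import torch

class ResidualConv3dBlock(torch.nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU whose input is added back; padding keeps the size identical."""
    def __init__(self, channels):
        super().__init__()
        self.conv = torch.nn.Conv3d(channels, channels, kernel_size=(3, 3, 3),
                                    stride=(1, 1, 1), padding=(1, 1, 1))
        self.bn = torch.nn.BatchNorm3d(channels)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(x + self.bn(self.conv(x)))  # the skip connection adds the input back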

But it seems you have an issue elsewhere. I notice your targets are under the variable “labels”. So why are you using MSELoss? If this is a classification problem, shouldn’t you be using CrossEntropyLoss or BCELoss?

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html


Never mind, I see the “heatmaps” reference now. So you want the NN to replicate the heatmap from something like an MRI, correct? Why is the number of output features equal to 17?
That seems a little odd. A heatmap should have just one channel.

I have a heatmap for each joint, 17 in total, which is why there are 17 channels. I want to be able to compare the 17 predicted channels (one per joint) against the target/label for the pose estimation task.
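
For context, a minimal sketch of how such per-joint target heatmaps are typically generated (one Gaussian blob per joint channel); the sigma and shapes here are illustrative, not the paper's values:

import torch

def make_target_heatmaps(keypoints, height, width, sigma=3.0):
    """Build (J, H, W) target heatmaps from J (x, y) joint coordinates."""
    ys = torch.arange(height).float().view(-1, 1)
    xs = torch.arange(width).float().view(1, -1)
    maps = [torch.exp(-((xs - x) ** 2 + (ys - y) ** 2) / (2 * sigma ** 2))  # one blob per joint
            for x, y in keypoints]
    return torch.stack(maps)

# e.g. 17 joints on a 200x200 frame -> a (17, 200, 200) target
target = make_target_heatmaps(torch.rand(17, 2) * 200, 200, 200)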


Since it’s a 3D mapping of a heatmap, the loss could be “low” simply because much of the space is zeros. As dimensionality increases, so does the empty space.

As such, the mean loss will be lower.

But generally, pose estimation models using heatmaps are trained this way. I would suspect that the network is too deep and has very little data to learn from.