Train/validation loss not decreasing

Hi, I am taking the output from my final convolutional transpose layer into a softmax layer and then trying to measure the MSE loss against my target. The problem is that my loss doesn't decrease and is stuck around the same point. I figured the problem is using the softmax in the last layer.

I need the softmax in the last layer because I want to measure probabilities. Any suggestion on how to overcome this problem would be helpful. Thanks

Ex:

input  = torch.randn(1, 16, 2, 256, 256)   # network input
output = torch.randn(1, 16, 1, 256, 256)   # model output
target = torch.randn(1, 16, 1, 256, 256)
criterion = torch.nn.MSELoss()


Hello,

If you want your network to output a probability distribution, it's better to use a cross-entropy loss (in case you have a one-hot target, as in classification, or a BCE loss in the case of 2 classes) or a KL-divergence loss (in case you have a target distribution). Those losses generally work better than MSE in such cases (there are theoretical justifications for this claim). A little advice: if you want to use cross-entropy loss, do not insert a softmax at the end of your model; CrossEntropyLoss as implemented in PyTorch works directly on the input logits for better numerical precision and stability.
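
For example, a minimal sketch (shapes illustrative, nothing from your model):

import torch

# CrossEntropyLoss takes raw logits; no softmax layer at the end of the model
logits = torch.randn(8, 5)            # (batch, num_classes)
target = torch.randint(0, 5, (8,))    # class indices
loss = torch.nn.CrossEntropyLoss()(logits, target)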

Hope it helps,
Thomas

Hey Thomas,
thanks for getting back. I am doing a regression-based pose estimation problem. From the paper, I found they use an MSE loss function with a softmax transformation to predict the keypoints and then compute the loss. My only problem is that the loss is stuck when the softmax is present. I want to know if there is some trick when using softmax with MSE loss to avoid this.

My model can make predictions but the loss doesn't improve.

import torch
import torch.nn as nn
from torchsummary import summary

class model(nn.Module):

    def __init__(self):
        super(model, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=64),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            torch.nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=128),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)),

            torch.nn.Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=256),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            torch.nn.Conv3d(256, 512, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=512),
            torch.nn.ReLU(inplace=True),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(512, 256, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 1)),
            torch.nn.BatchNorm3d(num_features=256),
            torch.nn.ReLU(inplace=True),

            torch.nn.Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=256),
            torch.nn.ReLU(inplace=True),

            torch.nn.ConvTranspose3d(256, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 1)),
            torch.nn.BatchNorm3d(num_features=128),
            torch.nn.ReLU(inplace=True),

            torch.nn.Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=128),
            torch.nn.ReLU(inplace=True),

            torch.nn.ConvTranspose3d(128, 16, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 1)),
        )

    def forward(self, image):
        """Putting the model together."""
        encoder = self.encoder(image)
        decoder = self.decoder(encoder)
        return torch.nn.Softmax(dim=1)(decoder)


model = model()
if torch.cuda.is_available():
    input = torch.rand(1, 3, 2, 256, 256).cuda()
    model = model.cuda()
    print("Output image shape:", model(input).shape)
else:
    summary(model, input_size=(3, 2, 200, 200), batch_size=1)
print("************************************************")
# torch.cuda.memory_summary(device=None, abbreviated=False)

From the paper:

  • The last convolutional layer of the decoder is followed by a transposed-convolutional layer that uses a linear activation followed by a Softmax transformation, resulting in an output of 17 confidence maps, one per keypoint.

  • The loss function calculated the Mean Squared Error (MSE) per pixel per map between the predicted confidence maps and the ground-truth confidence maps from the samples in the batch.
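
For reference, here is a minimal sketch of how I read that loss (17 maps; all shapes illustrative, not from my actual data):

import torch
import torch.nn.functional as F

pred_maps = torch.rand(4, 17, 256, 256)   # predicted confidence maps
gt_maps   = torch.rand(4, 17, 256, 256)   # ground-truth confidence maps
loss = F.mse_loss(pred_maps, gt_maps)     # mean squared error per pixel per map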

Mmmh, I don’t know of such a trick. Could you send a link to the paper? I’ll have a look (probably tomorrow, I won’t have much time today).

Yeah sure,

Ok, I had time to have a quick look. I hope I’m not wrong, but from what I understood, the softmax is used per map, not per pixel. If I understood right, looking at the generated confidence maps on page 6, the softmax is applied over the spatial dimensions, not across the channels, which means you want a spatial distribution and not a class distribution.

So if you follow the PyTorch convention for the dimensions of your images (batch, channels, x, y), then you should apply your softmax as:

return torch.nn.Softmax(dim=(2,3))(decoder)

I hope it’s right like that.

Ah, sorry, the softmax dim option cannot take a tuple, only an int, so you have to flatten your image before computing the softmax. Something like this should do the trick:

decoder_shape = decoder.shape
flatten_decoder = decoder.view(decoder_shape[0], decoder_shape[1], -1)
flatten_heat_map = torch.nn.Softmax(dim=2)(flatten_decoder)
return flatten_heat_map.view(decoder_shape)
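
As a quick sanity check (dummy shapes), each map should now be a spatial distribution summing to 1:

import torch

decoder = torch.randn(1, 16, 1, 64, 64)                   # dummy decoder output
flat = decoder.view(1, 16, -1)
heat = torch.nn.Softmax(dim=2)(flat).view(decoder.shape)
print(heat.sum(dim=(2, 3, 4)))                            # all ones, one per map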

Hey Thomas,
If I do this, then I need to reshape my ground-truth heatmap too:

decoder_shape = decoder.shape
flatten_decoder = decoder.view(decoder_shape[0], decoder_shape[1], -1)
flatten_heat_map = torch.nn.Softmax(dim=2)(flatten_decoder)
return flatten_heat_map.view(decoder_shape)

You do not have to reshape your target, because the last view, flatten_heat_map.view(decoder_shape), reshapes the predictions back to the original shape.
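
For instance (dummy shapes), MSELoss only needs the final shapes to match:

import torch

pred   = torch.rand(1, 16, 1, 64, 64)   # prediction, already viewed back to decoder_shape
target = torch.rand(1, 16, 1, 64, 64)   # target heatmap, unchanged
loss = torch.nn.functional.mse_loss(pred, target)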

Hey Thomas, I tried what you suggested and the outcome looks like the plot below.

class T_LEAP(nn.Module):
    """T-LEAP architecture."""

    def __init__(self):
        super(T_LEAP, self).__init__()
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=64),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            torch.nn.Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=128),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2)),

            torch.nn.Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=256),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            torch.nn.Conv3d(256, 512, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=512),
            torch.nn.ReLU(inplace=True),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(512, 256, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 1)),
            torch.nn.BatchNorm3d(num_features=256),
            torch.nn.ReLU(inplace=True),

            torch.nn.Conv3d(256, 256, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=256),
            torch.nn.ReLU(inplace=True),

            torch.nn.ConvTranspose3d(256, 128, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 1)),
            torch.nn.BatchNorm3d(num_features=128),
            torch.nn.ReLU(inplace=True),

            torch.nn.Conv3d(128, 128, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1)),
            torch.nn.BatchNorm3d(num_features=128),
            torch.nn.ReLU(inplace=True),

            torch.nn.ConvTranspose3d(128, 16, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 1)),
        )

    def forward(self, image):
        """Putting the model together."""
        encoder = self.encoder(image)
        decoder = self.decoder(encoder)
        decoder_shape = decoder.shape
        flatten_decoder = decoder.view(decoder_shape[0], decoder_shape[1], -1)
        flatten_heat_map = torch.nn.Softmax(dim=2)(flatten_decoder)
        return flatten_heat_map.view(decoder_shape)


model = T_LEAP()
if torch.cuda.is_available():
    input = torch.rand(1, 3, 2, 256, 256).cuda()
    model = model.cuda()
    print("Output image shape:", model(input).shape)
else:
    summary(model, input_size=(3, 2, 200, 200), batch_size=1)
print("************************************************")
# torch.cuda.memory_summary(device=None, abbreviated=False)

Sample training loop:

import matplotlib.pyplot as plt

# Get parameters, start training model
DEVICE    = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS    = 10
lr        = 1e-3
model     = T_LEAP()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
model = model.to(DEVICE)


# def softXEnt(input, target):
#     logprobs = torch.nn.functional.log_softmax(input, dim=1)
#     return -(target * logprobs).mean()

tr_losses  = []
val_losses = []

def loss_plot(epochs, train, val):
    ep = list(range(epochs))
    plt.plot(ep, train, label="Training loss")
    plt.plot(ep, val, label="Validation loss")
    plt.title("Training and validation loss")
    plt.legend()
    plt.show()

# Dummy data: one fixed image and one fixed target heatmap
image   = torch.rand(1, 3, 2, 160, 160).to(DEVICE)
heatmap = torch.rand(1, 16, 1, 160, 160).to(DEVICE)

for idx in range(EPOCHS):

    model.train()

    # Same fixed image and heatmap every epoch (dummy run)
    optimizer.zero_grad()
    output = model(image)
    loss   = criterion(output, heatmap)
    loss.backward()
    optimizer.step()

    curr_trloss = loss.item()
    print(f"Training loss {idx}:", curr_trloss)

    # Evaluation loop
    model.eval()
    with torch.no_grad():
        output = model(image)
        loss   = criterion(output, heatmap)
        curr_valoss = loss.item()

    print(f"Validation loss {idx}:", curr_valoss)
    print(" ")
    tr_losses.append(curr_trloss)
    val_losses.append(curr_valoss)
    
print("Training and evalutation is now complete")

loss_plot(EPOCHS,tr_losses, val_losses)

Output:
[Plot: training and validation loss curves; yellow = validation curve (sorry about the legends)]

If I understood right, you are doing a dummy training run with a constant image and heatmap? I see at least one issue: the heatmap is not normalized. When defining your dummy heatmap, do this:

heatmap = torch.rand(1,16,1,160,160).to(DEVICE)
heatmap = heatmap/torch.sum(dim=(2,3,4), keepdim=True)

But maybe something else is disturbing the training… still looking.

Someone on the forums suggested that I remove the softmax and MSE loss and try this instead:

def softXEnt(input, target):
    logprobs = torch.nn.functional.log_softmax(input, dim=1)
    return -(target * logprobs).mean()

In this case, the loss decreases, but the model learns poorly, as I can see from the predictions.

  • I do normalize my heatmaps while training, though.
  • If you remove the softmax, the loss will start to go down, though. In the paper they mention using a linear activation before the softmax; I am assuming it does nothing in PyTorch.

Also, this line throws back an error:

heatmap = heatmap/torch.sum(dim=(2,3,4), keepdim=True)

Sorry to keep adding info along the way. I thought it could be a helpful reference to rule out or consider new possibilities.

Yes, working on log probabilities is not a bad idea, but be cautious: in the code snippet you sent me, the log softmax is applied across the channels (dim=1) instead of the spatial dimensions.
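
Something like this instead, reusing the flatten trick (just a sketch adapting your softXEnt):

import torch

def softXEnt_spatial(input, target):
    # log-softmax over the flattened spatial dimensions: one distribution per map
    b, c = input.shape[0], input.shape[1]
    logprobs = torch.nn.functional.log_softmax(input.view(b, c, -1), dim=2)
    return -(target.view(b, c, -1) * logprobs).mean()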

Yes, sorry, I make some errors when writing code too quickly ^^ What I meant was this:

heatmap = torch.rand(1, 16, 1, 160, 160).to(DEVICE)
heatmap = heatmap / torch.sum(heatmap, dim=(2, 3, 4), keepdim=True)
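
A quick check that each map is now a distribution (continuing the snippet above):

print(heatmap.sum(dim=(2, 3, 4)))   # should print ones, shape (1, 16)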

And yes, when someone says “linear activation”, it usually means no activation function at the output of the layer (since the layer itself is linear).

Hey Thomas,
I am still encountering the same issue. Do you think my code is okay so far in comparison with the paper?

At least your training curve seems more consistent now, which is good news! I don’t understand the validation curve. When you do this (your snippet of code, no modification from my side):

for idx in range(EPOCHS):

    model.train()

    # Same fixed image and heatmap every epoch (dummy run)
    optimizer.zero_grad()
    output = model(image)
    loss   = criterion(output, heatmap)
    loss.backward()
    optimizer.step()

    curr_trloss = loss.item()
    print(f"Training loss {idx}:", curr_trloss)

    # Evaluation loop
    model.eval()
    with torch.no_grad():
        output = model(image)
        loss   = criterion(output, heatmap)
        curr_valoss = loss.item()

    print(f"Validation loss {idx}:", curr_valoss)
    print(" ")
    tr_losses.append(curr_trloss)
    val_losses.append(curr_valoss)

Then during both the training phase and the evaluation phase you are using the same input image and target heatmap. So I don’t get why the validation loss is not equal to the training loss… Are you using this snippet, or running something different?

I am using the same snippet of code you’ve mentioned in your post

Well, the only difference I see between the training phase and the validation phase is that the batch norm behaves differently in train mode vs. eval mode. Could you try running your code again, but commenting out (just for this test) all the batch norm layers in your model?
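
For illustration, a minimal sketch of why the two modes differ (dummy data; train mode normalizes with the batch statistics, eval mode with the running estimates):

import torch

bn = torch.nn.BatchNorm3d(4)
x = torch.randn(2, 4, 2, 8, 8)

bn.train()
out_train = bn(x)    # normalized with this batch's mean/var
bn.eval()
out_eval = bn(x)     # normalized with the running estimates
print(torch.allclose(out_train, out_eval))   # False in general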

I have commented out all of the batchnorm layers from the model, and the output looks like this:

[Plot: output with batchnorm commented out; the output layer followed by just torch.nn.Softmax(dim=1)]

Ok, so it was the batch norm. Now your validation and training curves are at least consistent with each other.
I’m just not sure I get the difference between the two experiments. So, in the first one you used the flatten trick to apply the softmax over the spatial dimensions, and in the second one you applied the softmax across the channels (dim=1), is that correct?
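
Just to make the difference concrete (dummy shapes):

import torch

decoder = torch.randn(1, 16, 1, 64, 64)

# softmax across channels: at each pixel, a distribution over the 16 maps
per_pixel = torch.nn.Softmax(dim=1)(decoder)

# softmax across pixels (flatten trick): within each map, a spatial distribution
flat = decoder.view(1, 16, -1)
per_map = torch.nn.Softmax(dim=2)(flat).view(decoder.shape)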