Custom loss function doesn't update weights (VGG-16)

Before I start: I’ve read the previous threads on this topic, but my problem still persists. I’ve implemented a custom loss function that looks like this:

import torch

def Cosine(output, target):
    '''
    Custom loss function with 2 terms:

    - loss_1: penalizes (x, y) pairs that are off the unit circle
    - loss_2: 0 when the predicted angle equals the target angle

    Inputs
        output: predicted phases as interleaved (x, y) pairs
        target: true phases (radians)
    '''

    # Penalize outputs that are off the unit circle
    squares = output ** 2  # (x ^ 2, y ^ 2)
    loss_1 = ((squares[:, ::2] + squares[:, 1::2]) - 1) ** 2  # (x ^ 2 + y ^ 2 - 1) ^ 2

    # Compute the second loss, 1 - cos(predicted angle - target)
    loss_2 = 1. - torch.cos(torch.atan2(output[:, 1::2], output[:, ::2]) - target)

    return torch.mean(loss_1 + loss_2)
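
One way to test the loss in isolation is to feed it a hand-built perfect prediction, then confirm that it returns roughly zero and that a gradient reaches its input. This is only a sketch. It assumes the output is laid out as interleaved (x_1, y_1, x_2, y_2, ...) with shape (N, 2K) and the target has shape (N, K). The names cosine_loss, theta, perfect and noisy are made up for illustration.

```python
import torch

def cosine_loss(output, target):
    # Same computation as Cosine above, restated so this snippet runs on its own.
    squares = output ** 2
    loss_1 = (squares[:, ::2] + squares[:, 1::2] - 1) ** 2
    angle = torch.atan2(output[:, 1::2], output[:, ::2])  # atan2(y, x)
    loss_2 = 1.0 - torch.cos(angle - target)
    return torch.mean(loss_1 + loss_2)

torch.manual_seed(0)
theta = (torch.rand(4, 8) * 2 - 1) * 3.0  # hypothetical targets in (-3, 3) radians

# Interleave cos and sin so that even columns are x and odd columns are y.
perfect = torch.stack((torch.cos(theta), torch.sin(theta)), dim=2).reshape(4, 16)
loss_perfect = cosine_loss(perfect, theta)

# A perturbed prediction should give a positive loss and a gradient on the input.
noisy = (perfect + 0.3 * torch.randn_like(perfect)).requires_grad_()
loss_noisy = cosine_loss(noisy, theta)
loss_noisy.backward()

print(loss_perfect.item(), loss_noisy.item(), noisy.grad is not None)
```

If this check passes, the loss itself is differentiable, and the place to look is which parameters the gradient is expected to reach.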

I’ve tried the following to debug my model:

for epoch in range(self.num_epochs):
    # Train
    network.train()  # training mode: dropout active, batch norm uses batch statistics
    print('\nEpoch {}'.format(epoch+1))
    print('\nTrain:\n')
    a = list(network.parameters())[0].clone()
    for batch_idx, (images, labels) in enumerate(Bar(loaders['train'])):
        # labels is a tensor of (512, 128) values if we use MyVgg
        images = images.to(self.device, dtype=torch.float)
        labels = labels.to(self.device, dtype=torch.float)
        optimizer.zero_grad()
        preds = network(images)
        loss = self.criterion(preds, labels)
        loss.backward()
        optimizer.step()

    b = list(network.parameters())[0].clone()
    print(torch.equal(a.data, b.data))  # This always prints True

    # Validation
    print('\nValid:\n')
    network.eval()  # eval mode: dropout disabled, batch norm uses running statistics
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(Bar(loaders['valid'])):
            images = images.to(self.device, dtype=torch.float)
            labels = labels.to(self.device, dtype=torch.float)
            preds = network(images)
            loss = self.criterion(preds, labels)

    if self.lr_scheduler:
        scheduler.step(loss)  # update lr_scheduler (uses the loss of the last validation batch)

Wherever I print list(self.network.parameters())[0].grad is None, I always get True. Also, the values in model.parameters() never change. The model predictions, true labels, and loss values all look as I expect, but backprop doesn’t seem to work.

Printing the loss gives me this:

tensor(1.0463, device='cuda:0', grad_fn=<MeanBackward0>)

Hello pvardanis!

From your other threads I have the impression that you are using
a pre-trained VGG-16 network with a little of your own network
architecture on the back end. (Is this correct?)

Is it possible that you are freezing the VGG-16 weights (not
unreasonable) and only updating (“fine-tuning”) the weights of
your back-end architecture? If so, it is possible that
list(network.parameters())[0] is a frozen VGG-16
weight, so that it doesn’t have a grad, and isn’t changed by
your optimization step?

Best.

K. Frank
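
The check suggested above can be sketched as follows. This is a minimal illustration, not the original poster's model: a two-layer nn.Sequential stands in for a frozen feature extractor plus a trainable head, and the names features, head, model and changed are hypothetical. It shows that parameter 0 keeps grad equal to None and never changes, while the trainable parameters do move after one optimizer step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in: a frozen "feature" layer followed by a trainable head.
features = nn.Linear(8, 8)
head = nn.Linear(8, 4)
for p in features.parameters():
    p.requires_grad = False  # frozen, as with pre-trained VGG-16 features
model = nn.Sequential(features, nn.ReLU(), head)

# Give the optimizer only the parameters that should be updated.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable, lr=0.01)

# Snapshot every parameter before the update.
before = {name: p.detach().clone() for name, p in model.named_parameters()}

x, y = torch.randn(16, 8), torch.randn(16, 4)
optimizer.zero_grad()
loss = ((model(x) - y) ** 2).mean()
loss.backward()
optimizer.step()

# Report, per parameter, whether it is trainable, has a gradient, and changed.
changed = {}
for name, p in model.named_parameters():
    changed[name] = not torch.equal(before[name], p.detach())
    print(name, p.requires_grad, p.grad is not None, changed[name])
```

Looping over named_parameters() instead of inspecting only index 0 makes it clear which layers are frozen and which are actually being trained.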

Hi KFrank,

You are correct. I guess that’s what happens when you’re new to PyTorch, haha. I printed the wrong grads. I have now plotted the gradients of all classifier layers with TensorBoard. Again, the model I use is VGG-16 with batch norm and pre-trained weights. I froze all the feature layers and trained only the classifier layers. Looking at the values, I’m getting suspicious.

Hyperparameters used for this one:

  • num_epochs = 10
  • batch_size = 64
  • lr = 0.01
  • SGD optimizer with default settings

I noticed that if I lower the batch size to 16 or less I get NaN values as output.

Here are the gradient histograms for the 3 layers of the classifier (0, 3 and 5; the layers in between are dropout and ReLU, which is why they’re not visualized):

And here are the gradient distributions for the 3 layers of the classifier:

These values are pretty low, or am I wrong? It looks like a vanishing-gradient problem to me, but I’m not sure.

I tried also the following approach:

A tanh activation on a final layer of 1024 neurons. Because of tanh, each output lies in the range [-1, 1], and the outputs are read as pairs (cos(theta_1), sin(theta_1)) up to (cos(theta_512), sin(theta_512)), so 1024 values in total, with x = cos(theta) and y = sin(theta).

Then tried to minimize as before:

  • loss_1 = (x ^ 2 + y ^ 2 - 1) ^ 2, to penalize (x, y) pairs that are off the unit circle
  • loss_2 = 1 - cos(atan2(y, x) - target), which is 0 if the predicted angle equals the target

So loss = loss_1 + loss_2 and that’s what I’m back propagating. Target is only a single value for each transducer in radians.
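
For reference, the angular term can be rewritten without atan2, which also makes the value 1 easier to interpret. For (x, y) != (0, 0), cos(atan2(y, x) - t) = (x cos t + y sin t) / sqrt(x^2 + y^2), so loss_2 is 1 minus the cosine similarity between (x, y) and (cos t, sin t). If the predicted angle is unrelated to the target, that cosine averages to 0 and loss_2 averages to 1. A loss stuck near 1 would therefore be consistent with predictions that carry no phase information, although it does not prove it. The sketch below checks both statements numerically using only the standard library. The function names are illustrative, and the uniformly distributed targets are an assumption.

```python
import math
import random

def loss_2_atan2(x, y, t):
    # Angular loss as written in the post: 1 - cos(atan2(y, x) - t).
    return 1.0 - math.cos(math.atan2(y, x) - t)

def loss_2_dot(x, y, t):
    # Equivalent form: 1 - cosine similarity between (x, y) and (cos t, sin t).
    return 1.0 - (x * math.cos(t) + y * math.sin(t)) / math.hypot(x, y)

random.seed(0)

# 1) The two forms agree for arbitrary non-zero (x, y).
max_diff = 0.0
for _ in range(1000):
    x, y = random.uniform(-1, 1), random.uniform(-1, 1)
    t = random.uniform(-math.pi, math.pi)
    max_diff = max(max_diff, abs(loss_2_atan2(x, y, t) - loss_2_dot(x, y, t)))

# 2) A fixed prediction against uniformly random targets averages to about 1.
n = 100_000
chance = sum(loss_2_atan2(1.0, 0.0, random.uniform(-math.pi, math.pi))
             for _ in range(n)) / n

print(max_diff, chance)
```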

Unfortunately, the results are the same. My loss never goes below 1.