Custom Loss Function Problem

Hi,
Is there any line in my loss function that would break backpropagation? After training starts, the loss value just fluctuates around its initial value and does not decrease noticeably.
I could not figure out why.
Thanks in advance.

Training loss over 10 epochs:
Train loss: 66.125526
Train loss: 69.031129
Train loss: 66.471873
Train loss: 74.184909
Train loss: 62.160489
Train loss: 63.237526
Train loss: 63.310347
Train loss: 62.977487
Train loss: 71.276782
Train loss: 64.684782

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class pairWiseLoss(nn.Module):

    def __init__(self, lambdaValue, lengthOfHashCode):
        super(pairWiseLoss, self).__init__()
        self.lambdaValue = lambdaValue
        self.l = lengthOfHashCode
        self.m = lengthOfHashCode

    def forward(self, binary1, labels1, binary2, labels2):
        # Per-pair label overlap: the i-th diagonal entry is labels1[i] · labels2[i]
        similarity = torch.diagonal(torch.mm(labels1, labels2.t()))
        maskCommonLabel = similarity.gt(0.0)
        maskNoCommonLabel = similarity.eq(0.0)

        tc = torch.exp(-similarity) * self.lambdaValue * 4 * self.l
        # p=0 counts the differing entries, i.e. the Hamming distance of each pair
        hammingDistance = torch.diagonal(torch.cdist(binary1, binary2, p=0))

        # Pull pairs with a common label within tc, push pairs without one at least m apart
        loss = (
            torch.sum(torch.masked_select(0.5 * F.relu(hammingDistance - tc), maskCommonLabel))
            + torch.sum(torch.masked_select(0.5 * F.relu(self.m - hammingDistance), maskNoCommonLabel))
        )

        return loss
```
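Written out, the forward pass above computes the following. This is my own transcription of the code, not notation from the thread: for the i-th pair, y^(1)_i and y^(2)_i are the label vectors, b^(1)_i and b^(2)_i the binary codes, λ is lambdaValue, and l = m is the hash-code length passed to `__init__`.

```latex
\begin{aligned}
s_i &= y^{(1)}_i \cdot y^{(2)}_i, \qquad
d_i = \sum_{k} \mathbb{1}\!\left[b^{(1)}_{ik} \neq b^{(2)}_{ik}\right], \qquad
t_i = 4\lambda l \, e^{-s_i}, \\
\mathcal{L} &= \frac{1}{2} \sum_{i:\, s_i > 0} \max\left(0,\, d_i - t_i\right)
  + \frac{1}{2} \sum_{i:\, s_i = 0} \max\left(0,\, m - d_i\right).
\end{aligned}
```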

I can’t see any line of code that would detach a tensor from the graph.
If you are concerned about detaching, you could check the .grad attribute of all parameters after the backward call. If they contain valid values, your graph wasn’t detached, and I would recommend trying to overfit a small data sample as a quick test.
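A minimal sketch of both checks together, assuming a toy setup: the tiny MLP, the random batch and `nn.CrossEntropyLoss` are hypothetical stand-ins for the real model, data and custom loss. It records the largest gradient of every parameter after the first backward pass and then tries to memorize a single fixed batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny MLP and one fixed batch instead of the real model and data.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(16, 8)
y = torch.randint(0, 4, (16,))
criterion = nn.CrossEntropyLoss()  # swap in the real loss for the actual check
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
first_grads = {}
for step in range(200):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    if step == 0:
        # Largest absolute gradient per parameter after the first backward pass;
        # an all-zero entry points to a dead gradient path.
        first_grads = {name: p.grad.abs().max().item() for name, p in model.named_parameters()}
    optimizer.step()
    losses.append(loss.item())

print(first_grads)
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```

If the loss on such a tiny, fixed batch does not drop close to zero, the problem is more likely in the model or loss than in the amount of data.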

@ptrblck, thanks for the reply.

```python
loss.backward()
print('loss Grad:', loss.grad)
```

It returns None. Is that normal?
I checked it with a predefined loss function: the training loss decreases as expected, but loss.grad is still None.

Should I check .grad for the layers of the network model? I don't understand what "all parameters" means.

Also, my dataset is small: 500 training, 250 validation, and 250 test samples.
When I use the larger dataset, the loss still does not decrease.

The .grad attribute is only retained for leaf tensors by default.
If you want to print it for the loss, you would need to call loss.retain_grad() before calling loss.backward().
However, note that this gradient will be 1. by default, if you didn’t pass a manual gradient argument via loss.backward(gradient=...).
Here is a small example:

```python
import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet18()
x = torch.randn(1, 3, 224, 224)
target = torch.zeros(1).long()

criterion = nn.CrossEntropyLoss()

out = model(x)
loss = criterion(out, target)
loss.retain_grad()  # use this to print the grad
loss.backward()

print(loss.grad)
# > tensor(1.)
```

To print the gradients of all parameters, you could use this code snippet after calling backward:

```python
# print grads of all parameters
for name, param in model.named_parameters():
    if param.grad is not None:  # parameters unused in the forward pass have no grad
        print(name, param.grad.abs().max())
```

@ptrblck, hi again.
As you said, loss.grad returns 1. However, there seem to be some problems with the gradients of the parameters.
All parameters keep the same weights and biases regardless of the epoch, except the FC.bias tensor.
At least they look the same, but I used the code you mentioned here:

Many parameters were printed, which means there are some changes between the old and new state, but they are very small, for example:

old_state_dict['encoder.1.weight']
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

new_state_dict['encoder.1.weight']
tensor([0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997,
        0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997,
        0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997,
        0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997,
        0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997,
        0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997,
        0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997,
        0.9997])
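Instead of eyeballing full tensors, the change per entry can be summarized as one number. A minimal sketch, assuming a hypothetical toy model and a hypothetical helper `max_abs_change`; the same function can be called on the real `old_state_dict` and `new_state_dict`. The important detail is the `deepcopy`: `state_dict()` returns tensors that share storage with the live parameters, so without a copy the "old" values would silently change too.

```python
import copy
import torch
import torch.nn as nn

def max_abs_change(old_state, new_state):
    """Return {name: max |new - old|} for every floating-point entry of two state dicts."""
    changes = {}
    for name, old in old_state.items():
        if old.is_floating_point():  # skips integer buffers such as num_batches_tracked
            changes[name] = (new_state[name] - old).abs().max().item()
    return changes

torch.manual_seed(0)
# Hypothetical toy model; replace it with the real network.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Linear(4, 2))

# state_dict() returns references to the live tensors, so copy it before training.
old_state = copy.deepcopy(model.state_dict())

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

changes = max_abs_change(old_state, model.state_dict())
for name, delta in changes.items():
    print(f"{name}: {delta:.3e}")
```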

Epoch 0:
FC.bias tensor(8.7036e-05, device='cuda:0')
FC.bias tensor(8.7035e-05, device='cuda:0')
FC.bias tensor(8.7035e-05, device='cuda:0')
Epoch 1:
FC.bias tensor(8.7034e-05, device='cuda:0')
FC.bias tensor(8.7032e-05, device='cuda:0')
FC.bias tensor(8.7031e-05, device='cuda:0')
Epoch 2:
FC.bias tensor(8.7029e-05, device='cuda:0')
FC.bias tensor(8.7028e-05, device='cuda:0')
FC.bias tensor(8.7026e-05, device='cuda:0')

So, are these updates large enough? Could this be the main reason the loss is not decreasing?

My network model:

```python
import torch
import torch.nn as nn
import torchvision.models as models

class ResNet50PairWise(nn.Module):
    def __init__(self, bits=16):
        super().__init__()

        resnet = models.resnet50(pretrained=False)

        # 12 input channels instead of the default 3
        self.conv1 = nn.Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        self.encoder = nn.Sequential(
            self.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
            resnet.avgpool
        )
        self.FC = nn.Linear(2048, bits)

        # weights_init_kaiming and fc_init_weights are my own init helpers (not shown here)
        self.apply(weights_init_kaiming)
        self.apply(fc_init_weights)

    def forward(self, x):
        x = self.encoder(x)
        x = x.view(x.size(0), -1)

        logits = self.FC(x)
        sign = torch.sign(logits)        # values in {-1, 0, 1}
        binary_out = torch.relu(sign)    # values in {0, 1}

        return binary_out
```

That might be the reason: as explained in the other topic, torch.sign could kill the gradients, since its derivative is zero wherever it is defined.
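This can be confirmed in isolation. A minimal sketch, assuming random logits as a stand-in for the output of `self.FC`: it applies the same `relu(sign(...))` head as the model above and backpropagates a dummy objective.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 16, requires_grad=True)

# Same head as ResNet50PairWise.forward: sign, then relu, gives {0, 1} codes.
binary_out = torch.relu(torch.sign(logits))
binary_out.sum().backward()

# sign has a zero derivative wherever it is defined, so no signal reaches the logits.
print(logits.grad.abs().max())  # tensor(0.)
```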
You could try to increase the learning rate and compare how large the updates would get (you can of course also increase the learning rate for a specific parameter set only).
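Raising the learning rate for only part of the model is done with optimizer parameter groups: any option set inside a group overrides the default passed to the optimizer. A minimal sketch, assuming a hypothetical `TinyNet` that only mirrors the `encoder` / `FC` attribute names of the model above; the 10x factor is an arbitrary example value.

```python
import torch
import torch.nn as nn

# Hypothetical toy model that mirrors the encoder/FC attribute names of ResNet50PairWise.
class TinyNet(nn.Module):
    def __init__(self, bits=16):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.FC = nn.Linear(64, bits)

    def forward(self, x):
        return self.FC(self.encoder(x))

model = TinyNet()

# Two parameter groups: the FC head gets a 10x larger learning rate than the encoder.
base_lr = 1e-3
optimizer = torch.optim.SGD(
    [
        {"params": model.encoder.parameters()},              # uses the default lr
        {"params": model.FC.parameters(), "lr": base_lr * 10},  # overrides it
    ],
    lr=base_lr,
    momentum=0.9,
)

for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"])
```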
Alternatively, you could try to use smooth approximations of the sign function, which wouldn’t yield a zero gradient.
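A minimal sketch of one such relaxation, assuming 0/1 codes as produced by `relu(sign(...))`: `torch.sigmoid(beta * logits)` is used as a smooth stand-in (the function names `soft_binary`, `relaxed_hamming` and the scale `beta` are hypothetical; `tanh(beta * x)` is the analogous choice for ±1 codes). The Hamming distance from `torch.cdist(..., p=0)` is a count of differing entries and therefore also piecewise constant, so the sketch replaces it with the squared Euclidean distance, which equals the Hamming distance when both codes are exactly 0/1.

```python
import torch

def soft_binary(logits, beta=5.0):
    """Smooth stand-in for relu(sign(logits)): values in (0, 1), approaching {0, 1} as beta grows."""
    return torch.sigmoid(beta * logits)

def relaxed_hamming(codes1, codes2):
    """Row-wise squared L2 distance; equals the Hamming distance when the codes are exactly 0/1."""
    return ((codes1 - codes2) ** 2).sum(dim=1)

torch.manual_seed(0)
logits1 = torch.randn(4, 16, requires_grad=True)
logits2 = torch.randn(4, 16, requires_grad=True)

dist = relaxed_hamming(soft_binary(logits1), soft_binary(logits2))
dist.sum().backward()
print(logits1.grad.abs().max() > 0)  # tensor(True)

# Sanity check: on exact 0/1 codes the surrogate reproduces the Hamming distance.
a = torch.tensor([[1., 0., 1., 1.]])
b = torch.tensor([[0., 0., 1., 0.]])
print(relaxed_hamming(a, b))  # tensor([2.])
```

During training the loss would use the smooth codes; for retrieval, the hard `relu(sign(...))` codes can still be produced at evaluation time.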
