Why doesn't my custom loss function work in my training process?

Say I have C classes, and I want a C*C similarity matrix in which each entry is the cosine similarity between class_i and class_j. I wrote the code below to compute a similarity loss based on the weights of the last-but-one fc layer. For simplicity, here is only part of the code:

cos = nn.CosineSimilarity(dim=1, eps=1e-6)
for batch_idx, (data, target) in enumerate(self.data_loader):
    # C*M
    weights = self.model.module.model.fc[-1].weight
    num_iter = weights.size(0)
    # similarity matrix
    sim_mat = torch.empty(0, num_iter).cuda()
    for j in range(num_iter):
        weights_i = weights[j, :].expand_as(weights)
        sim = cos(weights, weights_i)
        sim = torch.unsqueeze(sim, 0)
        sim_mat = torch.cat((sim_mat, sim), 0)
    sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))

    max_val = torch.max(sim_mat, dim=1).values

    loss_sim = (torch.sum(max_val) / num_classes**2)
    ...
    loss_sim.backward()  
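As an aside, the row-by-row torch.cat loop can be replaced by a single matrix product of L2-normalized rows. This is a minimal sketch, assuming weights is a C×M tensor; the helper names cosine_sim_matrix and cosine_sim_matrix_loop are hypothetical. F.normalize handles eps slightly differently from nn.CosineSimilarity, so results can differ for rows with near-zero norm.

```python
import torch
import torch.nn.functional as F


def cosine_sim_matrix(weights: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Return the C x C matrix of pairwise cosine similarities between rows."""
    # Scale each row to unit L2 norm (eps guards against division by zero).
    normed = F.normalize(weights, p=2, dim=1, eps=eps)
    # Entry (i, j) is the dot product of unit rows i and j, i.e. cos(w_i, w_j).
    return normed @ normed.t()


def cosine_sim_matrix_loop(weights: torch.Tensor) -> torch.Tensor:
    """The row-by-row construction from the question, for comparison."""
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    rows = [cos(weights, weights[j, :].expand_as(weights))
            for j in range(weights.size(0))]
    return torch.stack(rows, dim=0)


torch.manual_seed(0)
w = torch.randn(5, 8, requires_grad=True)
fast = cosine_sim_matrix(w)
slow = cosine_sim_matrix_loop(w)
print(torch.allclose(fast, slow, atol=1e-5))  # True
```

Both versions stay in the autograd graph, so either one can feed a loss that is backpropagated.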

Why doesn't my loss_sim work in my training process? It seems that loss_sim doesn't backpropagate properly to the original model weights.
I know we can wrap the loss function in a class like:

class Loss_sim(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, weights):
        ...

Should I write it like this? Is there any problem here caused by copying the original weights?
Thanks in advance 🙏
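For comparison, here is one way the Loss_sim skeleton above could be filled in. This is a sketch, not the poster's code: it computes the same quantity as loss_sim (row-wise maximum of the zero-diagonal similarity matrix, divided by C²), but swaps the loop for a normalized matrix product, and the nn.Linear(16, 4) layer is a hypothetical stand-in for fc[-1]. Wrapping the computation in an nn.Module does not change how gradients flow; a plain function behaves the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Loss_sim(nn.Module):
    """Sketch: the question's loss_sim packaged as a module."""

    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, weights: torch.Tensor) -> torch.Tensor:
        num_classes = weights.size(0)
        normed = F.normalize(weights, p=2, dim=1, eps=self.eps)
        sim_mat = normed @ normed.t()
        # Zero the diagonal, exactly as in the original code.
        sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
        max_val = torch.max(sim_mat, dim=1).values
        return torch.sum(max_val) / num_classes ** 2


# Usage: the module holds no parameters; it only sees the tensor passed in.
fc = nn.Linear(16, 4)          # hypothetical stand-in for fc[-1]
criterion = Loss_sim()
loss = criterion(fc.weight)
loss.backward()
print(fc.weight.grad is not None)  # True
```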

Howdy Hoody!

Perhaps the short answer is that your loss_sim is bounded below
by zero, and that might not be what you want.

I’m not sure what you are trying to do here, but I think you
want a loss term that pushes the rows of weights to be dissimilar
from one another.

Your loss_sim will backpropagate. Whether it will do so “properly”
depends on what you are expecting.

These lines:

    sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
    max_val = torch.max(sim_mat, dim=1).values

place zeros on the diagonal of sim_mat so that max_val can
never be negative.

The problem is that the rows of weights can all have negative
cosine similarity with one another, at which point loss_sim becomes
zero, has zero gradient, and no longer contributes to the training.
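The zero gradient follows from how torch.max backpropagates: only the entry that attains the maximum receives gradient. A tiny sketch with illustrative values, where the winning entry is the zeroed diagonal (modeled here as a constant, since diag - diag contributes no gradient to the weights):

```python
import torch

# Off-diagonal cosine similarities for one row of sim_mat, all negative.
off_diag = torch.tensor([-0.5, -0.3], requires_grad=True)
# The zeroed diagonal entry acts as a constant 0.0 with respect to the weights.
row = torch.cat((off_diag, torch.zeros(1)))

max_val = torch.max(row)
max_val.backward()

print(max_val.item())   # 0.0
print(off_diag.grad)    # tensor([0., 0.])
```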

Below is a script that runs your version of loss_sim, packaged as a function, cc_sim, and compares it with two possibly improved versions. cc_simA returns the mean of the cosine similarities, while cc_simB removes the floor of zero on loss_sim so that it can become negative and fall toward its most negative value (maximum dissimilarity).

This script shows that your version does backpropagate but gets stuck at zero, while the improved versions don't.

It also gives an example of a tensor, t, whose rows all have negative cosine similarity with one another.
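For reference, the pairwise cosine similarities of the rows of t can be worked out by hand with cos(a, b) = a·b / (‖a‖‖b‖), which also reproduces the three loss values the script reports for t (each sum is divided by num_classes² = 9):

```latex
t_0 = (1, 0),\quad t_1 = (-0.5, 1),\quad t_2 = (-0.5, -1),\qquad \|t_0\| = 1,\ \|t_1\| = \|t_2\| = \sqrt{1.25}

\cos(t_0, t_1) = \cos(t_0, t_2) = \frac{-0.5}{\sqrt{1.25}} \approx -0.4472,\qquad
\cos(t_1, t_2) = \frac{0.25 - 1}{1.25} = -0.6

\text{cc\_sim: } \tfrac{1}{9}\,(0 + 0 + 0) = 0 \quad (\text{each row's max is the zeroed diagonal})

\text{cc\_simA: } \tfrac{1}{9}\,\bigl(2(-0.4472) + 2(-0.4472) + 2(-0.6)\bigr) \approx -0.3321

\text{cc\_simB: } \tfrac{1}{9}\,\bigl(3 \times (-0.4472)\bigr) \approx -0.1491
```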


The script:

import torch
torch.__version__

torch.random.manual_seed (2020)

def cc_sim (weights):
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    num_iter = num_classes = weights.size(0)
    # similarity matrix
    sim_mat = torch.empty(0, num_iter)
    for j in range(num_iter):
        weights_i = weights[j, :].expand_as(weights)
        sim = cos(weights, weights_i)
        sim = torch.unsqueeze(sim, 0)
        sim_mat = torch.cat((sim_mat, sim), 0)
    sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
    max_val = torch.max(sim_mat, dim=1).values
    loss_sim = (torch.sum(max_val) / num_classes**2)
    return loss_sim

def cc_simA (weights):
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    num_iter = num_classes = weights.size(0)
    # similarity matrix
    sim_mat = torch.empty(0, num_iter)
    for j in range(num_iter):
        weights_i = weights[j, :].expand_as(weights)
        sim = cos(weights, weights_i)
        sim = torch.unsqueeze(sim, 0)
        sim_mat = torch.cat((sim_mat, sim), 0)
    sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
    loss_sim = torch.mean (sim_mat)
    return loss_sim

def cc_simB (weights):
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    num_iter = num_classes = weights.size(0)
    # similarity matrix
    sim_mat = torch.empty(0, num_iter)
    for j in range(num_iter):
        weights_i = weights[j, :].expand_as(weights)
        sim = cos(weights, weights_i)
        sim = torch.unsqueeze(sim, 0)
        sim_mat = torch.cat((sim_mat, sim), 0)
    sim_mat = sim_mat - torch.diag (float ('inf') * torch.ones (num_classes))
    max_val = torch.max(sim_mat, dim=1).values
    loss_sim = (torch.sum(max_val) / num_classes**2)
    return loss_sim

t = torch.tensor ([[1.0, 0.0], [-0.5, 1.0], [-0.5, -1.0]])
t.requires_grad = True
print ('cc_sim:')
loss = cc_sim (t)
loss.backward()
print ('loss =', loss)
print ('t = ...\n', t)
print ('t.grad = ...\n', t.grad)
with torch.no_grad():
    _ = t.grad.zero_()

print ('cc_simA:')
lossA = cc_simA (t)
lossA.backward()
print ('lossA =', lossA)
print ('t = ...\n', t)
print ('t.grad = ...\n', t.grad)
with torch.no_grad():
    _ = t.grad.zero_()

print ('cc_simB:')
lossB = cc_simB (t)
lossB.backward()
print ('lossB =', lossB)
print ('t = ...\n', t)
print ('t.grad = ...\n', t.grad)

nDim = 2
w = torch.randn ((3, nDim))
wA = w.clone()
wB = w.clone()
w.requires_grad = True
wA.requires_grad = True
wB.requires_grad = True
print ('w = ...\n', w)
print ('wA = ...\n', wA)
print ('wB = ...\n', wB)

lr = 5.0

print ('cc_sim:')
for  i in range (10):
    loss = cc_sim (w)
    print ('loss =', loss)
    if  i != 0:
        _ = w.grad.zero_()
    loss.backward()
    with torch.no_grad():
        _ = w.sub_ (lr * w.grad)

print ('w = ...\n', w)
print ('w.grad = ...\n', w.grad)

print ('cc_simA:')
for  i in range (10):
    lossA = cc_simA (wA)
    print ('lossA =', lossA)
    if  i != 0:
        _ = wA.grad.zero_()
    lossA.backward()
    with torch.no_grad():
        _ = wA.sub_ (lr * wA.grad)

print ('wA = ...\n', wA)
print ('wA.grad = ...\n', wA.grad)


print ('cc_simB:')
for  i in range (10):
    lossB = cc_simB (wB)
    print ('lossB =', lossB)
    if  i != 0:
        _ = wB.grad.zero_()
    lossB.backward()
    with torch.no_grad():
        _ = wB.sub_ (lr * wB.grad)

print ('wB = ...\n', wB)
print ('wB.grad = ...\n', wB.grad)

The output of the script:

>>> import torch
>>> torch.__version__
'1.6.0'
>>> 
>>> torch.random.manual_seed (2020)
<torch._C.Generator object at 0x7f635efa4930>
>>> 
>>> def cc_sim (weights):
...     cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
...     num_iter = num_classes = weights.size(0)
...     # similarity matrix
...     sim_mat = torch.empty(0, num_iter)
...     for j in range(num_iter):
...         weights_i = weights[j, :].expand_as(weights)
...         sim = cos(weights, weights_i)
...         sim = torch.unsqueeze(sim, 0)
...         sim_mat = torch.cat((sim_mat, sim), 0)
...     sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
...     max_val = torch.max(sim_mat, dim=1).values
...     loss_sim = (torch.sum(max_val) / num_classes**2)
...     return loss_sim
... 
>>> def cc_simA (weights):
...     cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
...     num_iter = num_classes = weights.size(0)
...     # similarity matrix
...     sim_mat = torch.empty(0, num_iter)
...     for j in range(num_iter):
...         weights_i = weights[j, :].expand_as(weights)
...         sim = cos(weights, weights_i)
...         sim = torch.unsqueeze(sim, 0)
...         sim_mat = torch.cat((sim_mat, sim), 0)
...     sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
...     loss_sim = torch.mean (sim_mat)
...     return loss_sim
... 
>>> def cc_simB (weights):
...     cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
...     num_iter = num_classes = weights.size(0)
...     # similarity matrix
...     sim_mat = torch.empty(0, num_iter)
...     for j in range(num_iter):
...         weights_i = weights[j, :].expand_as(weights)
...         sim = cos(weights, weights_i)
...         sim = torch.unsqueeze(sim, 0)
...         sim_mat = torch.cat((sim_mat, sim), 0)
...     sim_mat = sim_mat - torch.diag (float ('inf') * torch.ones (3))
...     max_val = torch.max(sim_mat, dim=1).values
...     loss_sim = (torch.sum(max_val) / num_classes**2)
...     return loss_sim
... 
>>> t = torch.tensor ([[1.0, 0.0], [-0.5, 1.0], [-0.5, -1.0]])
>>> t.requires_grad = True
>>> print ('cc_sim:')
cc_sim:
>>> loss = cc_sim (t)
>>> loss.backward()
>>> print ('loss =', loss)
loss = tensor(0., grad_fn=<DivBackward0>)
>>> print ('t = ...\n', t)
t = ...
 tensor([[ 1.0000,  0.0000],
        [-0.5000,  1.0000],
        [-0.5000, -1.0000]], requires_grad=True)
>>> print ('t.grad = ...\n', t.grad)
t.grad = ...
 tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
>>> with torch.no_grad():
...     _ = t.grad.zero_()
... 
>>> print ('cc_simA:')
cc_simA:
>>> lossA = cc_simA (t)
>>> lossA.backward()
>>> print ('lossA =', lossA)
lossA = tensor(-0.3321, grad_fn=<MeanBackward0>)
>>> print ('t = ...\n', t)
t = ...
 tensor([[ 1.0000,  0.0000],
        [-0.5000,  1.0000],
        [-0.5000, -1.0000]], requires_grad=True)
>>> print ('t.grad = ...\n', t.grad)
t.grad = ...
 tensor([[ 0.0000,  0.0000],
        [ 0.0168,  0.0084],
        [ 0.0168, -0.0084]])
>>> with torch.no_grad():
...     _ = t.grad.zero_()
... 
>>> print ('cc_simB:')
cc_simB:
>>> lossB = cc_simB (t)
>>> lossB.backward()
>>> print ('lossB =', lossB)
lossB = tensor(-0.1491, grad_fn=<DivBackward0>)
>>> print ('t = ...\n', t)
t = ...
 tensor([[ 1.0000,  0.0000],
        [-0.5000,  1.0000],
        [-0.5000, -1.0000]], requires_grad=True)
>>> print ('t.grad = ...\n', t.grad)
t.grad = ...
 tensor([[ 0.0000,  0.0994],
        [ 0.1590,  0.0795],
        [ 0.0795, -0.0398]])
>>> 
>>> nDim = 2
>>> w = torch.randn ((3, nDim))
>>> wA = w.clone()
>>> wB = w.clone()
>>> w.requires_grad = True
>>> wA.requires_grad = True
>>> wB.requires_grad = True
>>> print ('w = ...\n', w)
w = ...
 tensor([[ 1.2372, -0.9604],
        [ 1.5415, -0.4079],
        [ 0.8806,  0.0529]], requires_grad=True)
>>> print ('wA = ...\n', wA)
wA = ...
 tensor([[ 1.2372, -0.9604],
        [ 1.5415, -0.4079],
        [ 0.8806,  0.0529]], requires_grad=True)
>>> print ('wB = ...\n', wB)
wB = ...
 tensor([[ 1.2372, -0.9604],
        [ 1.5415, -0.4079],
        [ 0.8806,  0.0529]], requires_grad=True)
>>> 
>>> lr = 5.0
>>> 
>>> print ('cc_sim:')
cc_sim:
>>> for  i in range (10):
...     loss = cc_sim (w)
...     print ('loss =', loss)
...     if  i != 0:
...         _ = w.grad.zero_()
...     loss.backward()
...     with torch.no_grad():
...         _ = w.sub_ (lr * w.grad)
... 
loss = tensor(0.3133, grad_fn=<DivBackward0>)
loss = tensor(0.2794, grad_fn=<DivBackward0>)
loss = tensor(0.2203, grad_fn=<DivBackward0>)
loss = tensor(0.1282, grad_fn=<DivBackward0>)
loss = tensor(0.0270, grad_fn=<DivBackward0>)
loss = tensor(0.0421, grad_fn=<DivBackward0>)
loss = tensor(0., grad_fn=<DivBackward0>)
loss = tensor(0., grad_fn=<DivBackward0>)
loss = tensor(0., grad_fn=<DivBackward0>)
loss = tensor(0., grad_fn=<DivBackward0>)
>>> print ('w = ...\n', w)
w = ...
 tensor([[-0.5426, -1.7768],
        [ 1.8619, -0.0235],
        [-1.0342,  1.1212]], requires_grad=True)
>>> print ('w.grad = ...\n', w.grad)
w.grad = ...
 tensor([[0., 0.],
        [0., 0.],
        [0., 0.]])
>>> 
>>> print ('cc_simA:')
cc_simA:
>>> for  i in range (10):
...     lossA = cc_simA (wA)
...     print ('lossA =', lossA)
...     if  i != 0:
...         _ = wA.grad.zero_()
...     lossA.backward()
...     with torch.no_grad():
...         _ = wA.sub_ (lr * wA.grad)
... 
lossA = tensor(0.5826, grad_fn=<MeanBackward0>)
lossA = tensor(0.1014, grad_fn=<MeanBackward0>)
lossA = tensor(-0.2646, grad_fn=<MeanBackward0>)
lossA = tensor(-0.3099, grad_fn=<MeanBackward0>)
lossA = tensor(-0.3240, grad_fn=<MeanBackward0>)
lossA = tensor(-0.3299, grad_fn=<MeanBackward0>)
lossA = tensor(-0.3322, grad_fn=<MeanBackward0>)
lossA = tensor(-0.3330, grad_fn=<MeanBackward0>)
lossA = tensor(-0.3332, grad_fn=<MeanBackward0>)
lossA = tensor(-0.3333, grad_fn=<MeanBackward0>)
>>> print ('wA = ...\n', wA)
wA = ...
 tensor([[-1.0928, -1.7762],
        [ 1.6100, -0.0603],
        [-0.9511,  1.8146]], requires_grad=True)
>>> print ('wA.grad = ...\n', wA.grad)
wA.grad = ...
 tensor([[ 1.7566e-03, -1.0688e-03],
        [-3.6173e-05, -8.9793e-04],
        [ 1.2300e-03,  6.3939e-04]])
>>> 
>>> 
>>> print ('cc_simB:')
cc_simB:
>>> for  i in range (10):
...     lossB = cc_simB (wB)
...     print ('lossB =', lossB)
...     if  i != 0:
...         _ = wB.grad.zero_()
...     lossB.backward()
...     with torch.no_grad():
...         _ = wB.sub_ (lr * wB.grad)
... 
lossB = tensor(0.3133, grad_fn=<DivBackward0>)
lossB = tensor(0.2794, grad_fn=<DivBackward0>)
lossB = tensor(0.2203, grad_fn=<DivBackward0>)
lossB = tensor(0.1282, grad_fn=<DivBackward0>)
lossB = tensor(0.0040, grad_fn=<DivBackward0>)
lossB = tensor(-0.1192, grad_fn=<DivBackward0>)
lossB = tensor(-0.0680, grad_fn=<DivBackward0>)
lossB = tensor(-0.1348, grad_fn=<DivBackward0>)
lossB = tensor(-0.1079, grad_fn=<DivBackward0>)
lossB = tensor(-0.1362, grad_fn=<DivBackward0>)
>>> print ('wB = ...\n', wB)
wB = ...
 tensor([[-1.0508, -1.8293],
        [ 1.7102,  0.5224],
        [-1.1758,  1.3381]], requires_grad=True)
>>> print ('wB.grad = ...\n', wB.grad)
wB.grad = ...
 tensor([[ 0.0950, -0.0278],
        [ 0.0077, -0.0571],
        [ 0.0489,  0.0305]])
>>> 
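The -inf diagonal in cc_simB can also be built with a boolean mask via Tensor.masked_fill, which sizes the mask from weights and places it on weights' device. A minimal sketch (the name cc_simB_masked is hypothetical), using a normalized matrix product in place of the loop:

```python
import torch
import torch.nn.functional as F


def cc_simB_masked(weights: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    num_classes = weights.size(0)
    normed = F.normalize(weights, p=2, dim=1, eps=eps)
    sim_mat = normed @ normed.t()
    # Set the diagonal to -inf so torch.max never picks a row's self-similarity.
    diag_mask = torch.eye(num_classes, dtype=torch.bool, device=weights.device)
    sim_mat = sim_mat.masked_fill(diag_mask, float('-inf'))
    max_val = torch.max(sim_mat, dim=1).values
    return torch.sum(max_val) / num_classes ** 2


t = torch.tensor([[1.0, 0.0], [-0.5, 1.0], [-0.5, -1.0]], requires_grad=True)
lossB = cc_simB_masked(t)
lossB.backward()
print(round(lossB.item(), 4))   # -0.1491
```

masked_fill passes zero gradient to the masked entries, so the -inf values never produce NaN gradients.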

Good luck.

K. Frank


Thank you for your timely help. Basically, I want to penalize the most similar pair for each class. That's why I want a similarity loss that serves as a regularization term. My total loss is L_total = L_CE + L_sim. My questions are as follows:

  • I want to know whether this term can actually update the parameters of my network during training, because I assign the model weights to a new variable here: weights = self.model.module.model.fc[-1].weight.
  • If I want a weighted similarity loss, is the code below (based on yours) correct?
import torch
torch.__version__

torch.random.manual_seed (2020)
def cc_sim (weights, w_s):
    '''
    weights: weights from NN model
    w_s: weights for every entry in similarity matrix, w_s.size = (num_classes, num_classes)
    '''
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    num_iter = num_classes = weights.size(0)
    # similarity matrix
    sim_mat = torch.empty(0, num_iter)
    for j in range(num_iter):
        weights_i = weights[j, :].expand_as(weights)
        sim = cos(weights, weights_i)
        sim = torch.unsqueeze(sim, 0)
        sim_mat = torch.cat((sim_mat, sim), 0)
    sim_mat = sim_mat * w_s
    sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
    max_val = torch.max(sim_mat, dim=1).values
    loss_sim = (torch.sum(max_val) / num_classes**2)
    return loss_sim

def cc_simA (weights, w_s):
    '''
    weights: weights from NN model
    w_s: weights for every entry in similarity matrix, w_s.size = (num_classes, num_classes)
    '''
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    num_iter = num_classes = weights.size(0)
    # similarity matrix
    sim_mat = torch.empty(0, num_iter)
    for j in range(num_iter):
        weights_i = weights[j, :].expand_as(weights)
        sim = cos(weights, weights_i)
        sim = torch.unsqueeze(sim, 0)
        sim_mat = torch.cat((sim_mat, sim), 0)
    # element-wise multiply 
    sim_mat = sim_mat * w_s

    sim_mat = sim_mat - torch.diag(torch.diag(sim_mat))
    loss_sim = torch.mean (sim_mat)
    return loss_sim

def cc_simB (weights, w_s):
    '''
    weights: weights from NN model
    w_s: weights for every entry in similarity matrix, w_s.size = (num_classes, num_classes)
    '''
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    num_iter = num_classes = weights.size(0)
    # similarity matrix
    sim_mat = torch.empty(0, num_iter)
    for j in range(num_iter):
        weights_i = weights[j, :].expand_as(weights)
        sim = cos(weights, weights_i)
        sim = torch.unsqueeze(sim, 0)
        sim_mat = torch.cat((sim_mat, sim), 0)
    # element-wise multiply 
    sim_mat = sim_mat * w_s
    sim_mat = sim_mat - torch.diag (float ('inf') * torch.ones (num_classes))
    max_val = torch.max(sim_mat, dim=1).values
    loss_sim = (torch.sum(max_val) / num_classes**2)
    return loss_sim

if __name__ == '__main__':
    num_classes = nDim = 10
    weights = torch.randn((num_classes, nDim), requires_grad=True)
    w_s = torch.FloatTensor(num_classes, num_classes).uniform_(0, 1) 
    loss = cc_simB(weights, w_s)
    loss.backward()
    print('loss item:', loss)
    print('grad of weights:\n')
    print(weights.grad)

Thanks again!

Howdy Hoody!

If you’re asking whether the assignment:

weights = self.model.module.model.fc[-1].weight

will prevent gradients from flowing back through the assignment
and break backpropagation, the answer is no.

In Python, “variables” are references.

self.model.module.model.fc[-1].weight refers to a tensor
in memory somewhere. The above assignment creates a new
reference that refers to the same tensor in memory. No new
tensor is created, nor is any data copied from one place to
another. Performing tensor operations on weights is essentially
identical to performing tensor operations on
self.model.module.model.fc[-1].weight, so backpropagation
will work identically.

To emphasize this point:

wtmp1 = self.model.module.model.fc[-1].weight
wtmp2 = wtmp1
wtmp3 = wtmp2
weights = wtmp3

would be almost equivalent to your assignment, the only difference being the three extra (and unnecessary) temporary references, which will be cleaned up when your script exits (or when the enclosing function returns).
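A small sketch of this point, with a hypothetical nn.Sequential standing in for self.model.module.model.fc: the new name is the same object, and gradients computed through it land in the layer's own .grad. By contrast, .detach() (not used in your code) is the kind of call that would cut the graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for self.model.module.model.fc
fc = nn.Sequential(nn.Linear(8, 6), nn.ReLU(), nn.Linear(6, 3))

weights = fc[-1].weight                 # a second name for the same Parameter
wtmp = weights                          # yet another name; still no copy
print(wtmp is fc[-1].weight)            # True
print(wtmp.data_ptr() == fc[-1].weight.data_ptr())  # True

# Any loss built from `weights` backpropagates into the layer's own .grad.
loss = (weights ** 2).sum()
loss.backward()
print(torch.allclose(fc[-1].weight.grad, 2 * fc[-1].weight.detach()))  # True

# By contrast, .detach() returns a tensor cut off from autograd.
print(fc[-1].weight.detach().requires_grad)  # False
```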

As for your weighted version, sim_mat = sim_mat * w_s will do what I believe you want: each element of sim_mat is multiplied by the corresponding element of w_s. This can change the result of torch.max(), depending on the specific values involved.
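To illustrate with made-up numbers: elementwise weighting can change which pair torch.max() selects, and therefore which pair receives the gradient.

```python
import torch

# One row of similarities and the matching row of w_s (illustrative values).
sim_row = torch.tensor([0.9, 0.5, -0.2])
w_s_row = torch.tensor([0.1, 1.0, 1.0])

print(torch.max(sim_row, dim=0).indices.item())            # 0
print(torch.max(sim_row * w_s_row, dim=0).indices.item())  # 1  (0.09 < 0.5)
```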

Note that if you set up w_s so that it has zeros along its diagonal, the multiplication will zero out the diagonal of sim_mat, so you can drop sim_mat = sim_mat - torch.diag(torch.diag(sim_mat)).
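A quick check of this with random values, using Tensor.fill_diagonal_ to build the zero-diagonal w_s. This applies to the cc_sim and cc_simA variants; cc_simB sets the diagonal to -inf so that torch.max skips it, and a zero diagonal alone would bring back the floor at zero, so it still needs its own masking step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, nDim = 5, 8
weights = torch.randn(num_classes, nDim, requires_grad=True)
normed = F.normalize(weights, dim=1, eps=1e-6)
sim_mat = normed @ normed.t()

w_s = torch.rand(num_classes, num_classes)
w_s_zero_diag = w_s.clone()
w_s_zero_diag.fill_diagonal_(0.0)    # in-place: zeros on the diagonal

# Original approach: weight, then subtract the diagonal.
a = sim_mat * w_s
a = a - torch.diag(torch.diag(a))
# Alternative: a zero-diagonal w_s already zeros the diagonal of the product.
b = sim_mat * w_s_zero_diag

print(torch.equal(a, b))   # True
```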

Best.

K. Frank


Thanks for your clear explanation.

I still have one more question:

  • Why did my training and validation accuracy curves remain almost the same after I added the sim_loss term (like cc_simB above)? (Green is with sim_loss.) It's really weird.