```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineContrastiveLoss(nn.Module):
    """
    Cosine contrastive loss function.
    Based on:
    Label convention: 0 for a matching pair, 1 for a non-matching pair.
    If the pair matches, the loss is 1/4 * (1 - cos_sim)^2.
    If it doesn't, the loss is cos_sim^2 if cos_sim < margin, or 0 otherwise.
    The margin used in the paper is ~0.4.
    """

    def __init__(self, margin=0.4):
        super(CosineContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        cos_sim = F.cosine_similarity(output1, output2)
        loss_cos_con = torch.mean(
            (1 - label) * torch.div(torch.pow((1.0 - cos_sim), 2), 4)
            + label * torch.pow(cos_sim * torch.lt(cos_sim, self.margin), 2)
        )
        return loss_cos_con
```
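For reference, this is the loss the code is meant to compute (still using the paper's less-than rule) over a batch of N pairs, writing s_i for the cosine similarity of pair i, y_i for its label (0 = match, 1 = non-match), m for the margin and 1[·] for the indicator function:

$$
\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N}\left[(1-y_i)\,\frac{(1-s_i)^2}{4} + y_i\,\big(s_i\,\mathbb{1}[s_i < m]\big)^2\right]
$$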
However, I'm getting this error:

```
TypeError: mul received an invalid combination of arguments - got (torch.cuda.ByteTensor), but expected one of:
 * (float value)
      didn't match because some of the arguments have invalid types: (torch.cuda.ByteTensor)
 * (torch.cuda.FloatTensor other)
      didn't match because some of the arguments have invalid types: (torch.cuda.ByteTensor)
```
I know that `torch.lt()` returns a ByteTensor, but if I try to convert it to a FloatTensor with `torch.Tensor.float()`, I get `AttributeError: module 'torch.autograd.variable' has no attribute 'FloatTensor'`.
I'm really not sure where to go from here. It seems logical to me to multiply the cosine similarity tensor element-wise by a tensor of 0s and 1s produced by the less-than comparison.
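A minimal sketch of that masking idea, using made-up similarity values for four pairs. It assumes a current PyTorch release, where comparisons such as `torch.lt` return a `torch.bool` tensor rather than the `uint8` ByteTensor that older versions produced; calling `.float()` on the mask turns it into a 0/1 float tensor that can be multiplied element-wise with `cos_sim`.

```python
import torch

# Made-up cosine similarities for a batch of four pairs (illustrative only)
cos_sim = torch.tensor([0.9, 0.2, 0.5, -0.1])
margin = 0.4

# The comparison yields a boolean mask: True where cos_sim < margin
mask = torch.lt(cos_sim, margin)

# Convert the mask to float so the element-wise product keeps only the
# similarities that satisfy the less-than rule and zeroes out the rest
masked = cos_sim * mask.float()
print(mask.dtype, masked)
```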
You have to convert the tensor, not the Variable. A Variable is a container that builds the graph and does the backprop for you, so you need to construct the Variable from a FloatTensor.
Sorry, I think I get your point now. The result of `lt` is a ByteTensor that represents booleans. One option is to take the tensor out of the Variable with `torch.FloatTensor(lt(…).data)`.
Argh, I'm on my phone and can't try it. What if you call `.data.float()`? Or cast to a LongTensor instead of a FloatTensor? I can't find anything about tensor type conversion in the docs.
I'm a bit embarrassed to say the answer is simply to call `.float()` directly on the comparison result, as in `torch.lt(cos_sim, self.margin).float()`. Thanks for your help on this, though!
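I also realised that the paper has the comparison the wrong way round: for the non-match loss you want greater-than rather than less-than, so that non-matching pairs are only penalized when their cosine similarity exceeds the margin.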
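A sketch of the loss with both fixes from this thread applied (`.float()` on the comparison, and greater-than instead of less-than). It assumes a current PyTorch release, where Variable has been merged into Tensor, and a label tensor of 0s and 1s; the `label.float()` cast and the explicit `dim=1` are additions for robustness, not part of the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CosineContrastiveLoss(nn.Module):
    """Cosine contrastive loss (label 0 = matching pair, 1 = non-matching pair).

    Matching pairs:     1/4 * (1 - cos_sim)^2
    Non-matching pairs: cos_sim^2 if cos_sim > margin, else 0
    """

    def __init__(self, margin=0.4):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Ensure the label is a float tensor so it mixes cleanly with cos_sim
        label = label.float()
        cos_sim = F.cosine_similarity(output1, output2, dim=1)
        # Matching pairs are pulled towards cos_sim = 1
        match_loss = (1.0 - cos_sim).pow(2) / 4
        # Non-matching pairs are penalized only above the margin;
        # .float() turns the boolean comparison into a 0/1 mask
        over_margin = torch.gt(cos_sim, self.margin).float()
        non_match_loss = (cos_sim * over_margin).pow(2)
        return torch.mean((1.0 - label) * match_loss + label * non_match_loss)
```

For example, with one identical matching pair, one identical non-matching pair and one orthogonal non-matching pair, only the second pair contributes (cos_sim = 1 > 0.4, loss 1), so the mean loss is 1/3.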