Hi, I would like to know the proper way to write a custom loss function and return the correct value for the backward pass.
Here’s simplified code based on this repo:
pytorch-retinanet
Custom loss function:
class Focal_loss(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def binary_focal_loss(self, x, y, stabilization="None"):
        gamma = 2
        alpha = 0.25
        # one_hot_embedding comes from the repo; index 0 is background, so drop it
        y_true = one_hot_embedding(y.data.cpu(), self.num_classes + 1)
        y_true = y_true[:, 1:]
        y_true = Variable(y_true).cuda()
        if stabilization == "sigmoid":
            x = x.sigmoid()
        x = torch.clamp(x, 1e-4, 1.0 - 1e-4)
        # alpha_t: alpha for positives, (1 - alpha) for negatives
        alpha_weight = alpha * y_true + (1. - alpha) * (1. - y_true)
        # (1 - p_t), raised to the power gamma as in the focal loss paper
        focal_weight = y_true * (1. - x) + (1. - y_true) * x
        focal_weight = alpha_weight * focal_weight.pow(gamma)
        # element-wise binary cross-entropy, then apply the focal weighting
        bce_loss = -(y_true * torch.log(x) + (1. - y_true) * torch.log(1. - x))
        bce_loss = bce_loss * focal_weight
        return bce_loss.sum()

    def forward(self, cls_preds, cls_targets):
        '''cls_preds: (tensor) predicted class confidences, sized [batch_size, #anchors, #classes].
        cls_targets: (tensor) encoded target labels, sized [batch_size, #anchors].'''
        pos = cls_targets > 0
        # for normalization
        num_pos = pos.data.long().sum()
        pos_neg = cls_targets > -1  # drop ignored anchors (label -1)
        mask_cls_preds = cls_preds[pos_neg]
        cls_loss = self.binary_focal_loss(mask_cls_preds, cls_targets[pos_neg])
        total_loss = cls_loss / num_pos
        return total_loss
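For reference, here is a minimal sketch of how I exercise this loss on dummy data (num_classes, the anchor count, and the batch size are made-up values, just for illustration):

# Minimal sketch on dummy data; all shapes and counts here are assumptions
num_classes, num_anchors, batch_size = 20, 100, 8
criterion = Focal_loss(num_classes)
# leaf tensor on the GPU, so backward() has somewhere to write gradients
cls_preds = torch.randn(batch_size, num_anchors, num_classes,
                        device='cuda', requires_grad=True)
# targets: -1 = ignored anchor, 0 = background, 1..num_classes = object class
cls_targets = torch.randint(-1, num_classes + 1, (batch_size, num_anchors),
                            device='cuda')
loss = criterion(cls_preds, cls_targets)  # a single scalar
loss.backward()                           # gradients land in cls_preds.grad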
Notice that both cls_preds and cls_targets have shapes that start with [batch_size, ...].
So, my first question:
1. Does this mean that PyTorch automatically passes the whole batch through the network? Should I divide by the batch size after calculating the loss?
In this repo, I found:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=8, collate_fn=trainset.collate_fn)
criterion = FocalLoss()
for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(trainloader):
    loss = criterion(cls_preds, cls_targets)
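My current understanding, as a small sketch (net is a hypothetical stand-in for the detection model here, and optimizer is assumed to be set up as in the repo's train.py):

# Hypothetical walk-through of one iteration; shapes are illustrative
inputs, loc_targets, cls_targets = next(iter(trainloader))
print(inputs.shape)      # e.g. [8, 3, H, W] -- the DataLoader already stacked 8 images
cls_preds = net(inputs)  # one forward pass runs the whole batch at once
print(cls_preds.shape)   # e.g. [8, #anchors, #classes]
loss = criterion(cls_preds, cls_targets)
optimizer.zero_grad()
loss.backward()          # loss is a single scalar, so no extra arguments needed
optimizer.step()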
My second question: does the last line, loss = criterion(cls_preds, cls_targets), return the loss of a single image or of the whole batch of images?
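To make the question concrete, this is the kind of check I have in mind (purely illustrative, with made-up per-anchor losses):

# Illustrative check: does the scalar cover one image or the whole batch?
per_element = torch.rand(8, 100)    # pretend per-anchor losses for 8 images
per_image = per_element.sum(dim=1)  # one value per image, shape [8]
whole_batch = per_element.sum()     # a single scalar over the entire batch
assert torch.allclose(per_image.sum(), whole_batch)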
Thanks in advance!