# Problems with loss functions

It seems like there are some in-place operations in my loss functions. It would help greatly if someone could help me locate them. I can't find them at all!
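
Just to be clear about terminology, by "in-place" I mean operations that write into an existing tensor's storage instead of returning a new tensor. A generic illustration (not taken from my loss code):

```python
import torch

x = torch.ones(3)

x.add_(1)          # trailing-underscore methods modify x in place
x[x > 1.5] = 0.0   # indexed assignment also writes into x's storage in place
y = x + 1          # out-of-place: returns a new tensor, x is untouched
```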

```python
import torch

class DetectionFocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super(DetectionFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, classification, target):
        # Gather anchor states from target
        # anchor state is used to check how loss should be calculated
        # -1: ignore, 0: negative, 1: positive
        anchor_state = target[:, :, -1]
        target       = target[:, :, :-1]

        # Filter out ignored anchors
        indices = anchor_state != -1
        if torch.sum(indices) == 0:
            # Return 0 if all anchors are ignored
            return torch.tensor(0.0, device=classification.device)

        classification = classification[indices].clone()
        target         = target[indices].clone()

        # compute focal loss
        bce = -(target * torch.log(classification) + (1.0 - target) * torch.log(1.0 - classification))

        alpha_factor = torch.ones_like(target)
        alpha_factor = alpha_factor * self.alpha
        alpha_factor[target != 1] = 1 - self.alpha

        focal_weight = classification
        focal_weight[target == 1] = 1 - focal_weight[target == 1].clone()
        focal_weight = alpha_factor * focal_weight ** self.gamma

        cls_loss = focal_weight * bce

        # Compute the normalizing factor: number of positive anchors
        normalizer = torch.sum(anchor_state == 1).float()
        normalizer = max(normalizer, 1)

        # Normalize the summed loss by the number of positive anchors
        return torch.sum(cls_loss) / normalizer
```
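
For reference, this is roughly how I call the focal loss. The tensor shapes and the one-hot target layout here are just my illustration of what `forward` seems to expect (last channel of `target` is the anchor state, the rest are per-class labels); if the in-place issue is the autograd one, it should surface on the `backward()` call.

```python
import torch

# classification: per-anchor class probabilities in (0, 1), e.g. sigmoid outputs
logits = torch.randn(2, 8, 4, requires_grad=True)   # (batch, anchors, classes)
classification = torch.sigmoid(logits)

# target: per-class labels plus one extra anchor-state channel (-1/0/1)
target = torch.zeros(2, 8, 5)
target[:, :, 0]  = 1   # pretend class 0 is the ground truth for every anchor
target[:, :, -1] = 1   # mark every anchor as positive

loss = DetectionFocalLoss()(classification, target)
loss.backward()
```
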
```python
import torch

class DetectionSmoothL1Loss(torch.nn.Module):
    def __init__(self, sigma=3.0):
        super(DetectionSmoothL1Loss, self).__init__()
        self.sigma_squared = sigma ** 2

    def forward(self, regression, target):
        regression_target = target[:, :, :4]
        # anchor state is used to check how loss should be calculated
        # -1: ignore, 0: negative, 1: positive
        anchor_state      = target[:, :, 4]

        # only compute regression loss for positive anchors
        indices = anchor_state == 1
        if torch.sum(indices) == 0:
            # Return 0 if there are no positive anchors
            return torch.tensor(0.0, device=regression.device)

        regression        = regression[indices].clone()
        regression_target = regression_target[indices].clone()

        # compute smooth L1 loss
        # f(x) = 0.5 * (sigma * x)^2          if |x| < 1 / sigma / sigma
        #        |x| - 0.5 / sigma / sigma    otherwise
        regression_diff = regression - regression_target
        regression_diff = torch.abs(regression_diff)

        to_smooth = regression_diff < 1.0 / self.sigma_squared
        regression_loss = torch.zeros_like(regression_diff)
        regression_loss[to_smooth] = 0.5 * self.sigma_squared * regression_diff[to_smooth].clone() ** 2
        regression_loss[to_smooth == 0] = regression_diff[to_smooth == 0].clone() - 0.5 / self.sigma_squared

        # compute the normalizer: the number of positive anchors
        normalizer = torch.sum(anchor_state == 1).float()
        normalizer = max(normalizer, 1)

        # Normalize the summed loss by the number of positive anchors
        return torch.sum(regression_loss) / normalizer
```
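
The piecewise formula itself behaves as I expect when I evaluate it by hand (with the default sigma = 3, so sigma² = 9 and the switch point is 1/9 ≈ 0.111), so I suspect the problem really is just the in-place writes somewhere:

```python
sigma_squared = 3.0 ** 2                     # matches the default sigma = 3 above

for x in (0.05, 0.5):
    if abs(x) < 1.0 / sigma_squared:
        y = 0.5 * sigma_squared * x ** 2     # quadratic branch: 0.5 * 9 * 0.05**2 = 0.01125
    else:
        y = abs(x) - 0.5 / sigma_squared     # linear branch: 0.5 - 0.0556 ≈ 0.4444
    print(x, y)
```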