Hello,

I’m training a model and everything was working fine until I switched to focal loss. Since then, the DataLoader workers keep getting aborted and the run exits with `RuntimeError: DataLoader worker (pid 784890) is killed by signal: Aborted.`

It’s literally the only change I’ve made in the code. Previously I was using cross entropy loss with 8+ workers and it ran fine. As soon as I switch to focal loss, I get the error. I’ve tried lowering the number of workers to as low as 1 and reducing the batch size, and I still get the error. With 0 workers it works, but it’s far too slow.

So now I have two questions. First, what can I do? From what I’ve read about similar issues, it seems to be a memory consumption problem.

Second, what is it about focal loss in particular that causes this?

This is the loss code I’m using. I’ve also tried the torchvision implementation and got the same error.

```
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class FocalLoss(nn.Module):
    """ Focal Loss, as described in https://arxiv.org/abs/1708.02002.
    It is essentially an enhancement to cross entropy loss and is
    useful for classification tasks when there is a large class imbalance.
    x is expected to contain raw, unnormalized scores for each class.
    y is expected to contain class labels.
    Shape:
        - x: (batch_size, C) or (batch_size, C, d1, d2, ..., dK), K > 0.
        - y: (batch_size,) or (batch_size, d1, d2, ..., dK), K > 0.
    """

    def __init__(self,
                 alpha: Optional[Tensor] = None,
                 gamma: float = 0.,
                 reduction: str = 'mean',
                 ignore_index: int = -100):
        """Constructor.
        Args:
            alpha (Tensor, optional): Weights for each class. Defaults to None.
            gamma (float, optional): A constant, as described in the paper.
                Defaults to 0.
            reduction (str, optional): 'mean', 'sum' or 'none'.
                Defaults to 'mean'.
            ignore_index (int, optional): class label to ignore.
                Defaults to -100.
        """
        if reduction not in ('mean', 'sum', 'none'):
            raise ValueError(
                'Reduction must be one of: "mean", "sum", "none".')
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index
        self.reduction = reduction
        # NLLLoss registers alpha as a buffer, so .to(device) moves it too
        self.nll_loss = nn.NLLLoss(
            weight=alpha, reduction='none', ignore_index=ignore_index)

    def __repr__(self):
        arg_keys = ['alpha', 'gamma', 'ignore_index', 'reduction']
        arg_vals = [self.__dict__[k] for k in arg_keys]
        arg_strs = [f'{k}={v!r}' for k, v in zip(arg_keys, arg_vals)]
        arg_str = ', '.join(arg_strs)
        return f'{type(self).__name__}({arg_str})'

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        if x.ndim > 2:
            # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
            c = x.shape[1]
            x = x.permute(0, *range(2, x.ndim), 1).reshape(-1, c)
            # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
            # reshape (instead of view) also works for non-contiguous labels
            y = y.reshape(-1)
        unignored_mask = y != self.ignore_index
        y = y[unignored_mask]
        if len(y) == 0:
            # create the zero on the same device as the inputs, not the CPU
            return torch.tensor(0., device=x.device)
        x = x[unignored_mask]
        # compute weighted cross entropy term: -alpha * log(pt)
        # (alpha is already part of self.nll_loss)
        log_p = F.log_softmax(x, dim=-1)
        ce = self.nll_loss(log_p, y)
        # get true class column from each row
        # (row indices live on the same device as x)
        all_rows = torch.arange(len(x), device=x.device)
        log_pt = log_p[all_rows, y]
        # compute focal term: (1 - pt)^gamma
        pt = log_pt.exp()
        focal_term = (1 - pt)**self.gamma
        # the full loss: -alpha * ((1 - pt)^gamma) * log(pt)
        loss = focal_term * ce
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
```
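For reference, the same computation, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), can be written as a short function and used as a quick sanity check: with gamma = 0 and no alpha it should match plain cross entropy exactly, and with gamma > 0 it should come out smaller. This is only a minimal sketch, not the class above: `focal_loss` is a hypothetical helper name, and it assumes 2-D logits of shape (N, C), leaving out `ignore_index` and the K-dimensional reshaping.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def focal_loss(logits: torch.Tensor, target: torch.Tensor,
               gamma: float = 2.0,
               alpha: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical helper: mean focal loss for logits (N, C) and targets (N,)."""
    log_p = F.log_softmax(logits, dim=-1)
    # weighted cross entropy term per sample: -alpha_y * log(p_y)
    ce = F.nll_loss(log_p, target, weight=alpha, reduction="none")
    # log-probability of the true class for each row
    log_pt = log_p.gather(1, target.unsqueeze(1)).squeeze(1)
    # focal term (1 - p_t)^gamma shrinks the loss of well-classified samples
    focal_term = (1 - log_pt.exp()) ** gamma
    return (focal_term * ce).mean()


torch.manual_seed(0)
logits = torch.randn(8, 5, requires_grad=True)
target = torch.randint(0, 5, (8,))

fl0 = focal_loss(logits, target, gamma=0.0)  # should equal cross entropy
ce = F.cross_entropy(logits, target)
fl2 = focal_loss(logits, target, gamma=2.0)  # should be smaller than ce
fl2.backward()  # gradients flow back to the logits
print(fl0.item(), ce.item(), fl2.item())
```

Using `gather` instead of indexing with `torch.arange` keeps the index tensor on whatever device the logits are on.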