Device-side assert triggered (FocalLoss)

I am trying to use a FocalLoss implementation I found (shown below), but I keep getting a “device-side assert triggered” error. It seems to be raised on the line that moves self.alpha to CUDA when it isn't there already, and I can't understand why doing that would throw an error. If anyone can see what is wrong, it would be much appreciated.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            alpha = torch.ones(class_num)
        elif not isinstance(alpha, torch.Tensor):
            alpha = torch.tensor(alpha, dtype=torch.float)
        # shape (class_num, 1) so that alpha[targets] is (N, 1), like probs
        self.alpha = alpha.reshape(-1, 1)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=1)

        # one-hot mask: 1 at each sample's target class, 0 elsewhere
        class_mask = inputs.new_zeros(N, C)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)  # requires 0 <= target < C

        # move the class weights to the same device as the inputs
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        alpha = self.alpha[ids.view(-1)]  # per-sample class weight, (N, 1)

        probs = (P * class_mask).sum(1).view(-1, 1)  # p_t for each sample
        log_p = probs.log()
        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
```
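A frequent cause of this assert in code like the above is a target label outside `[0, class_num - 1]` (for example labels numbered from 1 instead of 0, or an ignore value such as -1): both `scatter_` and the `self.alpha[...]` indexing use the targets as indices. Because CUDA kernels run asynchronously, the failure is often reported at a later line, such as the one that copies `self.alpha` to the GPU. A minimal sketch of a sanity check you could run on the targets before calling the loss; `check_targets` is a hypothetical helper, not part of PyTorch.

```python
import torch


def check_targets(targets: torch.Tensor, class_num: int) -> None:
    """Raise a readable error if any label falls outside [0, class_num - 1].

    Hypothetical helper: call it on the targets before FocalLoss.forward.
    """
    if targets.dtype != torch.long:
        raise TypeError(f"targets must be torch.long, got {targets.dtype}")
    lo = int(targets.min())
    hi = int(targets.max())
    if lo < 0 or hi >= class_num:
        raise ValueError(
            f"targets must be in [0, {class_num - 1}], got min={lo}, max={hi}"
        )


# Usage: labels 1..5 for a 5-class problem are off by one.
targets = torch.tensor([1, 2, 5])
try:
    check_targets(targets, class_num=5)
except ValueError as err:
    print(err)  # targets must be in [0, 4], got min=1, max=5
```

If the check fails, shifting the labels (for example `targets - 1`) or filtering out ignore values before the loss is usually the fix.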

Does your code work on the CPU?
Running it there would give you a better error message than the current device-side assert.
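To illustrate the CPU route, here is a small sketch that runs the same one-hot `scatter_` step as `FocalLoss.forward` on a hypothetical toy batch with one out-of-range label. On the CPU the bad index raises an ordinary Python exception with a message instead of an opaque device-side assert (the exact exception type and wording depend on the PyTorch version, so the sketch catches both `RuntimeError` and `IndexError`).

```python
import torch

# Hypothetical toy batch: 3 samples, 5 classes, one label out of range.
inputs = torch.randn(3, 5)
targets = torch.tensor([0, 2, 5])  # 5 is invalid for 5 classes (valid: 0..4)

# Same one-hot step as in FocalLoss.forward, run on the CPU.
class_mask = inputs.new_zeros(inputs.size(0), inputs.size(1))
caught = None
try:
    class_mask.scatter_(1, targets.view(-1, 1), 1.0)
except (RuntimeError, IndexError) as err:
    # On the CPU this is a regular exception describing the bad index.
    caught = err
    print(err)
```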

If it’s working on the CPU, rerun the script with:


and post the stack trace here, please.
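Related to getting a useful stack trace on the GPU: because CUDA kernels run asynchronously, the line shown next to a device-side assert is often not the line that failed. One commonly used option is the `CUDA_LAUNCH_BLOCKING=1` environment variable, which makes kernel launches synchronous so the error is raised at the offending operation. It has to be set before CUDA is initialized; a minimal sketch setting it from inside the script (it slows execution, so use it for debugging only):

```python
import os

# Must be set before the first CUDA call (safest: before importing torch),
# otherwise it may have no effect for this process.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (imported after setting the variable on purpose)

if torch.cuda.is_available():
    # From here on, CUDA errors should be reported at the kernel that caused them.
    print("CUDA kernel launches are synchronous for this process")
```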