Expected object of scalar type Float but got scalar type Half for argument #3 'weight'

I’m using CrossEntropyLoss in my object detection model, shown below, together with APEX amp at the “O2” level for automatic mixed precision training.

class PredictionHead(torch.nn.Module):
    def __init__(self, class_weights):
        super(PredictionHead, self).__init__()

        class_weights = torch.tensor(class_weights, requires_grad=False)
        self.loss_module_class = torch.nn.CrossEntropyLoss(weight=class_weights, reduction="none")

    def forward(self, x):
        # ...

    def loss(self, predictions, labels):
        return self.loss_module_class(predictions, labels)

But I get the error below. It seems that APEX amp converted class_weights to half precision, while the CrossEntropyLoss computation requires full floats.

  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py", line 932, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 2317, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 2115, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Float but got scalar type Half for argument #3 'weight' in call to _thnn_nll_loss_forward
  In call to configurable 'train_model' (<function train_model at 0x7f111c9a7f28>) in scope 'train'
  In call to configurable 'train' (<function _run_job at 0x7f1153a08730>)

I also opened an APEX GitHub issue: https://github.com/NVIDIA/apex/issues/837. Any suggestion on how to skip the conversion of the CrossEntropyLoss weights from full floats to half floats?

It’s working fine now after I moved the CrossEntropyLoss construction from __init__() to loss().
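For anyone hitting the same error, here is a minimal sketch of that workaround (not the exact code from the model above): the weight tensor is kept as a plain float32 attribute, and CrossEntropyLoss is built inside loss() instead of __init__(), so amp’s module-level half-precision cast never touches it. The .float() cast on predictions is an extra safeguard in case the logits arrive as half.

```python
import torch

class PredictionHead(torch.nn.Module):
    def __init__(self, class_weights):
        super(PredictionHead, self).__init__()
        # Keep the raw weights as a plain float32 attribute (not a registered
        # buffer), so a module-wide cast to half precision leaves it alone.
        self.class_weights = torch.tensor(class_weights, dtype=torch.float32)

    def loss(self, predictions, labels):
        # Construct the loss here, with the float32 weights moved to the
        # same device as the predictions.
        loss_fn = torch.nn.CrossEntropyLoss(
            weight=self.class_weights.to(predictions.device),
            reduction="none",
        )
        # Cast predictions to float32 so the nll_loss kernel sees matching dtypes.
        return loss_fn(predictions.float(), labels)
```

Building the loss module per call adds negligible overhead, since CrossEntropyLoss holds no state beyond the weight tensor.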