Mixed precision training with Apex

Hello all,
I am trying to convert my model to half precision and then use mixed-precision training so that it uses less memory. I am not sure whether this is the correct way to do it:


    def train_model(self, model, dataloader, epochs):
        cudnn.benchmark = True
        optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
        model.half()
        model ,optimizer = amp.initialize(model, optimizer, opt_level='01')
        model.train()
        model.cuda()
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)
        #criterion = torch.nn.CrossEntropyLoss().cuda()
        criterion = torch.nn.BCELoss().cuda()
        for i in range(0, epochs):
            scheduler.step()
            train_accuracy = 0
            net_loss = 0
            for _, (data, label) in enumerate(dataloader):
                optimizer.zero_grad()
                data = data.half().cuda()
                label = label.cuda()
                out = model(data)
                loss = criterion(out, label)
                with amp.scale_loss(loss,optimizer) as scaled_loss:
                    scaled_loss.backward()
                    optimizer.step()
                if torch.argmax(out) == label:
                    train_accuracy +=1
                net_loss += loss.item()
            print('------------------------------------------')
            print('EPOCH ', i)
            print(train_accuracy/len(dataloader))
            print(net_loss/len(dataloader))

Could someone please point me in the right direction?
Thanks in advance!

nn.BCELoss is most likely unsafe to use in mixed-precision training, and you should get a warning about it.
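A small illustration of the usual replacement, assuming the model currently ends in a sigmoid that can be dropped so it returns raw logits: nn.BCEWithLogitsLoss folds the sigmoid into the loss and gives the same value as Sigmoid followed by nn.BCELoss, while being numerically more stable and safe to run under amp.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 1)                    # raw model outputs, no sigmoid applied
labels = torch.randint(0, 2, (8, 1)).float()  # binary targets as floats

# Pattern from the question: sigmoid output fed into BCELoss (flagged as unsafe under amp).
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(logits), labels)
# Replacement: BCEWithLogitsLoss applies the sigmoid internally and takes logits directly.
loss_with_logits = nn.BCEWithLogitsLoss()(logits, labels)

print(loss_sigmoid_bce.item(), loss_with_logits.item())
```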
Besides that, you shouldn't call model.half() yourself, and you should push the model to the device before initializing amp.
The apex docs should give you an example of the intended usage.
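A minimal sketch of the question's loop with that ordering applied, assuming NVIDIA apex is installed, a CUDA device is available, the model returns raw logits and the labels match the output shape. The name train_model_apex is hypothetical. Beyond the points above, it passes opt_level="O1" (apex expects the letter O, not zero), calls optimizer.step() after the scale_loss block as the apex examples do, and steps the scheduler once per epoch after the optimizer steps, which PyTorch 1.1 and later expects.

```python
import torch
import torch.nn as nn


def train_model_apex(model, dataloader, epochs, lr=0.01):
    """Sketch of the apex O1 setup order; requires NVIDIA apex and a CUDA GPU."""
    from apex import amp  # optional dependency, imported only when training actually runs

    torch.backends.cudnn.benchmark = True
    model.cuda()  # 1. move the FP32 model to the GPU; no manual model.half()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # 2. build the optimizer afterwards
    # 3. let amp handle the casting; "O1" uses the letter O, not the digit zero
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    criterion = nn.BCEWithLogitsLoss()  # amp-safe replacement for Sigmoid + BCELoss

    model.train()
    for epoch in range(epochs):
        for data, label in dataloader:
            # Inputs stay FP32; amp casts inside the model where it is safe to do so.
            data, label = data.cuda(), label.cuda().float()
            optimizer.zero_grad()
            loss = criterion(model(data), label)
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()  # outside the scale_loss context
        scheduler.step()  # once per epoch, after the optimizer steps
    return model
```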

That being said, we recommend trying out torch.cuda.amp, available in the nightly binaries or when building from source, since it is now part of PyTorch itself and doesn't require building apex.
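A minimal torch.cuda.amp sketch of the same loop, assuming the model returns raw logits and the labels are 0/1 values with the output's shape. train_model and binary_accuracy are illustrative names, not part of PyTorch. Autocast chooses the precision per op, so neither the model nor the inputs are cast with .half(), and GradScaler takes over the role of amp.scale_loss. binary_accuracy thresholds the logits at zero instead of comparing a single torch.argmax over the whole batch to the labels. On a machine without CUDA, autocast and the scaler are simply disabled and the loop runs in FP32.

```python
import torch
import torch.nn as nn


def binary_accuracy(logits, labels):
    # logit > 0 is equivalent to sigmoid(logit) > 0.5, i.e. a positive prediction.
    return ((logits > 0).float() == labels).float().mean().item()


def train_model(model, dataloader, epochs, lr=0.01):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"  # autocast and GradScaler become no-ops when disabled
    model.to(device)  # keep the parameters in FP32
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    criterion = nn.BCEWithLogitsLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    history = []
    model.train()
    for epoch in range(epochs):
        net_loss, net_acc = 0.0, 0.0
        for data, label in dataloader:
            data, label = data.to(device), label.to(device).float()
            optimizer.zero_grad()
            # Forward pass and loss run under autocast; backward runs outside it.
            with torch.cuda.amp.autocast(enabled=use_amp):
                out = model(data)
                loss = criterion(out, label)
            scaler.scale(loss).backward()  # scale the loss to avoid FP16 gradient underflow
            scaler.step(optimizer)  # unscales gradients, skips the step on inf/NaN
            scaler.update()  # adjusts the scale factor for the next iteration
            net_loss += loss.item()
            net_acc += binary_accuracy(out.detach(), label)
        scheduler.step()  # once per epoch, after the optimizer steps
        history.append((net_loss / len(dataloader), net_acc / len(dataloader)))
        print(f"epoch {epoch}: loss {history[-1][0]:.4f}, accuracy {history[-1][1]:.3f}")
    return history
```

In newer releases the same functionality is also exposed as torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda").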