Converting a model to 16-bit precision (float16) instead of 32-bit

Hi,
I am trying to train my model in mixed precision, so I am calling:
model.half()
But I get the following error:

So I converted my inputs and labels to half as well, but the error now seems to be caused by the line:

loss.backward()

I have tried converting the loss back to a float value and running it again, but I still get the following error:

Any suggestions?

Calling model.half() would not train the model in mixed precision, but in pure half precision.
Automatic mixed precision is available via torch.cuda.amp.
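
To make the distinction concrete, here is a small sketch comparing the two. It assumes PyTorch 1.10 or newer for the device-generic `torch.autocast` context manager (the thread itself uses the CUDA-specific `torch.cuda.amp.autocast`), and it falls back to bfloat16 on CPU, which is an assumption made only so the sketch also runs without a GPU.

```python
import torch
import torch.nn as nn

# Pure half precision: .half() converts every parameter to float16.
half_model = nn.Linear(8, 4).half()
print(half_model.weight.dtype)  # torch.float16

# Mixed precision: parameters stay float32, and autocast picks a lower
# precision per operation, only inside the context manager.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Assumption: float16 on CUDA, bfloat16 as the CPU autocast dtype.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

amp_model = nn.Linear(8, 4).to(device)
x = torch.randn(2, 8, device=device)
with torch.autocast(device_type=device, dtype=amp_dtype):
    out = amp_model(x)

print(amp_model.weight.dtype)  # torch.float32 (unchanged)
print(out.dtype)               # float16 on CUDA, bfloat16 on CPU
```

Only the outputs of autocast-eligible ops such as `nn.Linear` come out in lower precision; the stored weights are left untouched.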

Could you post your model definition as well as the input shapes so that we can reproduce this error, please?

Hi,
I am trying to run the model in half-precision only. Sorry for the confusion.

My model runs for a few batches, sometimes even for a whole epoch, but then it throws an error. My training code looks similar to this:

        images  = Variable(images).cuda()
        labels  = Variable(labels).cuda()


        net.half()  ## 
        images = images.half() ##

        #forward
        logits = net(images)
        logits = logits.float() ##

        loss = criterion(logits, labels) 

        loss.backward() 
        net.float() ##
        optimizer.step()

Variables have been deprecated since PyTorch 0.4, so you can use tensors directly now.
Could you post an executable code snippet using random tensors, so that we could reproduce the issue and debug further?
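
As a minimal sketch of the tensor-only style (the shapes are arbitrary placeholders), autograd works directly on tensors, and the legacy `Variable` wrapper simply returns a tensor:

```python
import torch
from torch.autograd import Variable  # legacy import, kept only for comparison

device = "cuda" if torch.cuda.is_available() else "cpu"

# Plain tensors replace Variable(images).cuda() / Variable(labels).cuda().
images = torch.randn(2, 3, device=device)
weights = torch.randn(3, 1, device=device, requires_grad=True)

loss = (images @ weights).sum()
loss.backward()                  # gradients flow without any wrapper
print(weights.grad.shape)        # torch.Size([3, 1])

# Since 0.4, Variable(t) just hands back a torch.Tensor.
legacy = Variable(torch.randn(2, 3))
print(type(legacy))
```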

My actual training loop looks like the code below; the code above was just an example. Apologies for the confusion:

def train_classifier(classifier, train_loader, optimizer, criterion):
    classifier.half()
    classifier.train()
    loss = 0.0
    losses = []
    for i, (images, labels) in enumerate(train_loader):
        classifier.half()
        images, labels = images.to(device), labels.float().to(device)
        images = images.half()
        optimizer.zero_grad()
        logits = classifier(images)
        logits = logits.float()
        loss = criterion(logits, labels)
        loss = loss.float()
        loss.backward()
        classifier.float()
        optimizer.step()
        losses.append(loss)
    return torch.stack(losses).mean().item()

Any idea what could be wrong?

I guess the second error might be raised because you are converting the model to half() and back to float() repeatedly during training, which could cause dtype mismatches.
Could you explain your use case for converting the model back and forth and, if possible, post an executable code snippet? Simple models (e.g. resnet18) seem to work for me.
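
A minimal illustration of such a dtype mismatch, using a toy `nn.Linear` rather than the model from this thread: once the parameters are float16, feeding a float32 input raises a `RuntimeError`. The exact message differs across PyTorch versions and devices.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2).half()   # all parameters are now float16
x = torch.randn(1, 4)            # the input is still float32

try:
    model(x)
    mismatch_error = None
except RuntimeError as err:      # float32 input meets float16 weights
    mismatch_error = str(err)

print({p.dtype for p in model.parameters()})  # {torch.float16}
print(mismatch_error)
```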

Since this is the first time I am converting a model to half precision, I just followed the post below. It converted the model back and forth between float and half, so I thought this was the correct way.

But I get an error even in the first epoch if I don't convert the model back to float. The modified code looks like this:

def train_classifier(classifier, train_loader, optimizer, criterion):
    classifier.half()
    classifier.train()
    loss = 0.0
    losses = []
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.float().to(device)
        images = images.half()
        optimizer.zero_grad()
        logits = classifier(images)
        logits = logits.float()
        loss = criterion(logits, labels)
        loss = loss.float()
        loss.backward()
        optimizer.step()
        losses.append(loss)
    return torch.stack(losses).mean().item()

The error I am getting is:

I would still recommend using automatic mixed precision if you want stable FP16 training, since numerically sensitive operations are then automatically performed in FP32.

Could you still post the model definition and an executable code snippet to reproduce the issue? I'm unable to reproduce this error using standard torchvision models.

I am using resnet34 as my base model, with its last layer replaced by a few linear layers followed by a sigmoid. My code looks like this:

temp = nn.Sequential(
          nn.Dropout(p=0.5),
          nn.Linear(in_features=512, out_features=128),
          nn.ReLU(),
          nn.Linear(in_features=128, out_features=17, bias=True),
          nn.Sigmoid()
        )

classifier = torchvision.models.resnet34(pretrained=True)
classifier.fc = temp  

I am using the Adam optimizer with BCELoss, and for every epoch I just call the function above:

train_classifier(classifier, )

Thanks for the update.
If I run your code snippet, I get invalid outputs after two iterations since the model is overflowing, which is creating an error in the criterion and thus a CUDA assert failure:

/opt/conda/conda-bld/pytorch_1603729047590/work/aten/src/ATen/native/cuda/Loss.cu:102: operator(): block: [0,0,0], thread: [33,0,0] Assertion `input_val >= zero && input_val <= one` failed.
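
The assertion above is `BCELoss` rejecting inputs outside [0, 1]; a NaN produced by an overflow fails the same check. On the CPU the check surfaces as a `RuntimeError` instead of a device-side assert. A small sketch with made-up values:

```python
import torch
import torch.nn as nn

criterion = nn.BCELoss()
target = torch.tensor([1.0, 0.0])

# Valid probabilities in [0, 1] are accepted.
ok = criterion(torch.tensor([0.9, 0.1]), target)
print(round(ok.item(), 4))  # 0.1054

# A NaN (e.g. from an overflowed activation) fails the input range check.
try:
    criterion(torch.tensor([float("nan"), 0.1]), target)
    range_error = None
except RuntimeError as err:
    range_error = str(err)
print(range_error)
```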

Again, I would advise against using FP16 directly, as overflows and underflows can easily happen.
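
For reference, float16 has a narrow range. This sketch prints its limits via `torch.finfo` and shows one value overflowing to `inf` and another underflowing to zero when cast down:

```python
import torch

fp16 = torch.finfo(torch.float16)
print(fp16.max)    # 65504.0, the largest finite float16 value
print(fp16.tiny)   # smallest positive normal float16 value (about 6.1e-05)

big = torch.tensor([70000.0]).half()   # above the maximum -> inf
small = torch.tensor([1e-8]).half()    # below the smallest subnormal -> 0
print(big, small)
```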

Hi,

I am trying to run autocast, but I am facing the following issue:

AttributeError: module 'torch.cuda.amp' has no attribute 'autocast'

Any idea?

Could you update to the latest stable version (1.7.0) and retry importing it?
torch.cuda.amp.autocast was introduced in 1.6.0, but I would recommend using the latest version, since it ships with the latest bug fixes and additional features.
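
A quick way to check what the installed build provides, as a sketch: it reads `torch.__version__` and tests for the attribute. Note that in recent releases `torch.cuda.amp.autocast` still exists but is deprecated in favour of `torch.autocast(device_type="cuda")`.

```python
import torch

# torch.cuda.amp.autocast is available from PyTorch 1.6.0 onwards.
major, minor = (int(part) for part in torch.__version__.split(".")[:2])
has_autocast = hasattr(torch.cuda.amp, "autocast")

print("PyTorch", torch.__version__)
print("torch.cuda.amp.autocast available:", has_autocast)
```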

I am able to run autocast, but the loss is causing an issue. My code looks like this:

with torch.cuda.amp.autocast():
          logits = classifier(images)
          loss = criterion(logits, labels)

I am getting an error in the loss calculation. The error says I should use a loss function other than BCELoss, but I need a sigmoid layer right before the output. Which loss should I use, given that PyTorch recommends a loss that works with logits?
Or is there some way to make it work?

The error message is:

Your model should return the raw logits and you should use nn.BCEWithLogitsLoss as the criterion.
If you want to see the probabilities, you could still apply torch.sigmoid to them, but don’t pass them to the loss function.
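
A sketch of the suggested setup, using random logits with the 17 outputs from this thread: in float32, `nn.BCEWithLogitsLoss` on raw logits gives the same value as `nn.BCELoss` on `torch.sigmoid(logits)`, but it fuses the sigmoid into the loss in a numerically stable way, which is why autocast can run it. The probabilities are computed separately, only for inspection.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 17)                     # raw model outputs, no sigmoid
target = torch.randint(0, 2, (4, 17)).float()

loss_with_logits = nn.BCEWithLogitsLoss()(logits, target)
loss_after_sigmoid = nn.BCELoss()(torch.sigmoid(logits), target)
print(loss_with_logits.item(), loss_after_sigmoid.item())  # equal up to rounding

# Probabilities for inspection/metrics only; they are not fed to the loss.
probs = torch.sigmoid(logits)
preds = (probs > 0.5).float()
```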

I am using the code below to train the model in mixed precision. I have also commented out the sigmoid line, but I am still facing the issue.

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.float().to(device), labels.float().to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            logits = classifier(images)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        # Updates the scale for next iteration.
        scaler.update()

But I get the following error on the line scaler.scale(loss).backward():

Sorry for bugging you.

This code is working fine for me:

import torch
import torch.nn as nn
import torchvision.models as models

temp = nn.Sequential(
          nn.Dropout(p=0.5),
          nn.Linear(in_features=512, out_features=128),
          nn.ReLU(),
          nn.Linear(in_features=128, out_features=17, bias=True),
        )

classifier = models.resnet34(pretrained=True)
classifier.fc = temp

device = 'cuda'  
classifier.to(device)
optimizer = torch.optim.SGD(classifier.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

data = torch.randn(2, 3, 224, 224, device=device)
target = torch.randint(0, 2, (2, 17), device=device).float()
criterion = nn.BCEWithLogitsLoss()

for epoch in range(10):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        logits = classifier(data)
        loss = criterion(logits, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    # Updates the scale for next iteration.
    scaler.update()
    print('epoch {}, loss {:.3f}'.format(epoch, loss.item()))
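
The `GradScaler` calls in the loop above exist because small FP16 gradients can underflow to zero. The arithmetic below is a rough sketch of the idea, not of `GradScaler`'s internals: scaling by 2**16 (the default `init_scale`) before casting keeps a tiny value representable, and dividing by the same factor in FP32 recovers it.

```python
import torch

grad = torch.tensor(1e-8)            # a tiny float32 gradient value
print(grad.half())                   # underflows to 0 in float16

scale = 2.0 ** 16                    # GradScaler's default init_scale
scaled = (grad * scale).half()       # about 6.55e-04, representable in float16
unscaled = scaled.float() / scale    # divide out the scale in float32
print(scaled, unscaled)
```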

Hi,

It seems I was also manually calling half(), which was causing this error.
Thank you!
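
Since the fix was to drop the manual `half()` call, a small guard can catch it early. `assert_fp32_params` is a hypothetical helper, not a PyTorch API; it relies on the fact that `GradScaler` refuses to unscale FP16 gradients, so AMP training expects the parameters to stay float32.

```python
import torch
import torch.nn as nn

def assert_fp32_params(model: nn.Module) -> None:
    """Hypothetical helper: raise if any parameter is not float32."""
    bad = [name for name, p in model.named_parameters() if p.dtype != torch.float32]
    if bad:
        raise TypeError(f"AMP expects float32 parameters, found others in: {bad}")

model = nn.Sequential(nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, 17))
assert_fp32_params(model)   # passes: freshly built modules are float32

model.half()                # the stray call that broke AMP training
try:
    assert_fp32_params(model)
    caught = False
except TypeError:
    caught = True
print(caught)
```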

If the model's weights aren't fp16 during mixed-precision training, then the trained model won't be fp16 in the end, right?

Or is it the case that you can convert to fp16 after training with AMP?