Working with half model and half input

Hi, I am trying to use the model.half() to get rid of out of memory error. So, first I want to check it for a simple model and then try it for a complex model.

In the following example, I am using model.half()for a simple CNN and MNIST dataset. I think I did all thing correctly, but I got Nan for loss.
I am wondering if you kindly let me know where I am doing a mistake.

mnist = torchvision.datasets.MNIST('./data',train= True,download=True,transform =transforms.ToTensor())
data_loader = DataLoader(mnist,batch_size=20,num_workers=2,shuffle=True)

class Model(nn.Module):
    # Our model

    def __init__(self):
        super(Model, self).__init__()
        
        self.fc1 = nn.Conv2d(1,10,3)
        self.bn1 = nn.BatchNorm2d(10)
        self.fc2= nn.Conv2d(10,20,3)
        self.bn2 = nn.BatchNorm2d(20)
        self.fc3= nn.Linear(11520,10)
        
    def forward(self,x):
        
        x = F.relu(self.fc1(x))
        
        x = x.type(torch.float32)   
        x = self.bn1(x)
                
        x = x.type(torch.float16)
        
        x = F.relu(self.fc2(x))
        
        x = x.type(torch.float32)
        
        x = self.bn2(x)
        
        
        x = x.type(torch.float16)
        
        x = x.view(x.size(0),-1)
                
        x = self.fc3(x)
        return(x)

device = torch.device('cuda:6' if torch.cuda.is_available() else 'cpu')

model = Model().half().to(device)

for layer in model.modules():
    
    if isinstance(layer,nn.BatchNorm2d):
        
        layer.float()

optimizer = optim.Adam(model.parameters(),lr=0.1)

lr_sch = lr_scheduler.StepLR(optimizer,step_size=2,gamma=0.1)

criterion = nn.CrossEntropyLoss()

def train(epoch):
    model.train()
    t_loss = 0
    for X,y in data_loader:   
        X= X.half().to(device)
        y = y.long().to(device)
        pred = model(X)
        loss = criterion(pred,y)
        t_loss+= loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return(t_loss/len(data_loader.dataset))  
num_epochs = 20
train_loss = []
for epoch in range(num_epochs):
    t_loss = train(epoch)
    train_loss.append(t_loss)
    print(f'Epoch: {epoch}: Loss: {t_loss:0.4}')

Epoch: 0: Loss: 4.346e-05
Epoch: 1: Loss: nan
Epoch: 2: Loss: nan
Epoch: 3: Loss: nan
Epoch: 4: Loss: nan
Epoch: 5: Loss: nan
Epoch: 6: Loss: nan
Epoch: 7: Loss: nan
Epoch: 8: Loss: nan
Epoch: 9: Loss: nan
Epoch: 10: Loss: nan
Epoch: 11: Loss: nan
Epoch: 12: Loss: nan
Epoch: 13: Loss: nan
Epoch: 14: Loss: nan
Epoch: 15: Loss: nan
Epoch: 16: Loss: nan
Epoch: 17: Loss: nan
Epoch: 18: Loss: nan
Epoch: 19: Loss: nan

Calling model.half() manually can easily yield NaN and Inf outputs, as some internal values can overflow.
We recommend to use automatic mixed precision training as described here, which takes care of these issues for you.
To use amp you would have to install the nightly binary or build from master.

1 Like