Models can't seem to learn with FP16 training, and the training time is absurd

I’m trying to benchmark the performance (time taken to train, GPU memory usage) while training models in FP16 (“O3”) and FP32 (“O0”).

It is known that FP16 uses comparatively less GPU memory and trains faster on Tensor Cores.

I used a Tesla T4 (on Google Colab) to train these models, which, from what I read online, supports Tensor Cores.
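
Tensor Cores need a GPU with compute capability 7.0 or higher (the T4 reports 7.5), which can be confirmed with a quick check like this:

import torch

major, minor = torch.cuda.get_device_capability(0)
print(torch.cuda.get_device_name(0), (major, minor))   # Tesla T4 -> (7, 5)
print("Tensor Cores available:", (major, minor) >= (7, 0))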

I trained the model with all four AMP opt levels (“O0”, “O1”, “O2”, “O3”).

  1. But I couldn’t see any significant difference in training time, and for the most part the model doesn’t learn anything.
  2. I don’t know why the memory values come out the way they do.
  3. I can see the gradients turn to NaN after the first step (see the check right after this list). I tried adding BatchNorm, but it didn’t help.
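
To narrow down point 3, a small helper like the one below (just a sketch; check_nan_grads is an ad-hoc name) can be called right after backward() to see which parameter blows up first:

import torch

def check_nan_grads(model):
    # print every parameter whose gradient contains NaN after backward()
    for name, p in model.named_parameters():
        if p.grad is not None and torch.isnan(p.grad).any():
            print("NaN gradient in", name)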

I’m not sure what I’m doing wrong or missing.

Any help would be great, thanks in advance.

wandb output link: Weights & Biases
The stdout logs are attached on each run’s page.

My Model

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):

    def __init__(self):
        super(Net,self).__init__()

        self.convs = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 3, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(128, 256, 3, 1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )

        self.linears = nn.Sequential(
            nn.Linear(12*12*256,128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128,10)
        )
    
    def forward(self,x):
        x = self.convs(x)
        x = x.view(x.shape[0],-1)
        x = self.linears(x)
        logits = F.log_softmax(x,dim=1)
        return logits
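
For reference, the 12*12*256 input size of the first linear layer can be sanity-checked with a dummy CIFAR-10 batch (3x32x32 inputs shrink to 24x24 after the four 3x3 convolutions, and MaxPool2d(2) halves that to 12x12):

net = Net()
with torch.no_grad():
    out = net.convs(torch.randn(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 256, 12, 12]) -> 12*12*256 after flattening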

My Training Loop

import torch
import torch.nn as nn
import apex.amp
import wandb
from tqdm import tqdm, trange

class Pipeline:
    def __init__(self,config):
        # do stuff with config
        # optimizer -> Adam
        self.loss_fn = nn.NLLLoss()
         
    def trainStep(self,batch,half:bool=False):
        X,y = batch
        X,y = X.to(self.device),y.to(self.device)
        
        # forward
        self.model.train()
        self.optimizer.zero_grad()  # clear gradients accumulated from the previous step
        logits = self.model(X)
        loss = self.loss_fn(logits,y) 
        acc = torch.argmax(logits.detach(),dim=1)==y
        acc = torch.mean(acc.float())

        # backward: apex scales the loss so small FP16 gradients don't underflow
        with apex.amp.scale_loss(loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        self.optimizer.step()

        return loss,acc

    def trainLoop(self,epochs=None,half:bool=False,opt:str="O0",log_config=None):

        if epochs is None:
            epochs = self.epochs

        self.model,self.optimizer = apex.amp.initialize(self.model,self.optimizer,opt_level=opt,loss_scale="dynamic")
        wandb.watch(self.model,self.loss_fn,log="all",log_freq=1)
    
        eg_count = 0
        t = trange(epochs)
        for epoch in t:
            b = tqdm(self.train_loader)
            for batch in b:
                eg_count += len(batch[0])  # number of examples in the batch, not the tuple length
                loss,acc = self.trainStep(batch)
                
                self.trainLog(loss,acc,eg_count,epoch)

        torch.cuda.empty_cache()

    def pipe(self,project:str,run:str,half:bool=False,val:bool=False,test:bool=False):
        with wandb.init(project=project,name=run,config=self.config):
            opt = run[-2:]

            self.trainLoop(opt=opt)

        # create a new model and optimizer for the next run
        self.model = self.config['modelClass']().to(self.device)    
        self.optimizer = torch.optim.Adam(self.model.parameters(),lr=self.learning_rate)


config = {
    "epochs" : 5,
    "batch_size" : 1024,
    "learning_rate" : 1e-2,
    "half" : False,
    "classes" : 10,
    "dataset" : "CIFAR10",
    "architecture" : "CNN",
    "kernels" : [32,64,128,256],
    "modelClass" : Net
}

pl = Pipeline(config)
pl.pipe("first",run="fp32-O0")
pl.pipe("first",run="fp16-O3")
pl.pipe("first",run="fp-Mixed-O2")
pl.pipe("first",run="fp-Mixed-O1")

apex.amp is deprecated, so we recommend using the native mixed-precision training utilities via torch.cuda.amp, as described here.
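
For reference, a minimal native-AMP training step with torch.cuda.amp might look like this (a sketch only; train_step is an illustrative name, assuming the same model, optimizer, and NLLLoss as above):

import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, loss_fn, X, y, device):
    model.train()
    X, y = X.to(device), y.to(device)
    optimizer.zero_grad()
    # run the forward pass in mixed precision
    with torch.cuda.amp.autocast():
        logits = model(X)
        loss = loss_fn(logits, y)
    # scale the loss to avoid FP16 gradient underflow; step() unscales
    # the gradients and skips the update if they contain inf/NaN
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.item()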