Slow Training Time with CNN

(Because of forum rules I cannot post 5 images, so I am attaching a Google Drive link with the images named Figure_x.png. Please refer to these: https://drive.google.com/open?id=1iqrjsqHd6wZyliSSTdEfAAkonXYI3sky)

I have been running into some pretty strange runtime issues while training a CNN. For reference, I am training it on a T4 in a Google Cloud Platform VM.

My original architecture was ResNet-style, with these layers:

FIGURE 1

I had been training and testing this with different hyperparameters since it only took about 30 minutes to an hour to train. All of a sudden it began taking extremely long, so I reduced the architecture to this:

FIGURE 2

I decided to run a profiler on the training loop with only 16 images (roughly as sketched after the figure), and these are the results:

FIGURE 3
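
For context, the profiling run can be reproduced with something along these lines. This is only a sketch, not my exact script; it assumes the same model, criterion, optimizer, device, and dataloaders used in the training loop at the bottom of this post:

import torch
import torch.autograd.profiler as profiler

with profiler.profile(use_cuda=True) as prof:
    for inputs, labels in dataloaders['train']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(labels, outputs)   # same call signature as in the training loop below
        loss.backward()
        optimizer.step()

# per-operator totals, sorted by time spent on the GPU
print(prof.key_averages().table(sort_by='cuda_time_total'))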

Since loss.item() and test_loss.item() are only called to gather training and validation statistics, I tried profiling the code without these calls, and this happened:

FIGURE 4

Overall, the training loop took the same time to run! This was pretty strange to my boss and me, so we thought about it a bit more and came to the conclusion that this might be happening because, until loss.item() is called, none of the GPU work PyTorch has queued up is actually forced to finish; removing the call just moves the wait to the next synchronization point.
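
To illustrate what I mean, here is a crude, hypothetical timing check (it assumes a single inputs/labels batch already on the GPU and the same model, criterion, and device as below). Because CUDA calls are asynchronous, almost no time should show up in the forward pass itself, and the real cost should only appear once something forces a synchronization:

import time
import torch

torch.cuda.synchronize()              # make sure nothing is still queued on the GPU
t0 = time.time()
outputs = model(inputs)               # kernel launches return almost immediately
loss = criterion(labels, outputs)     # same call signature as in the training loop below
t1 = time.time()
torch.cuda.synchronize()              # now actually wait for the queued GPU work
t2 = time.time()

print('launch time: {:.4f}s, time waiting on the GPU: {:.4f}s'.format(t1 - t0, t2 - t1))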

Next, I looked at CPU, GPU, and GPU memory usage (using nvtop) to see what was going on during this very long inputs.to(device) call. I did this by incrementing batch_size by 1 from 1 to 16 (at batch_size 17 I got an out-of-memory error). CPU and GPU usage remained fairly similar across the runs, but memory usage showed some very strange behavior.

FIGURE 5

As I increased the batch size, memory usage increased normally, then dropped off around batch_size = 7, then increased normally again until the out-of-memory error.
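
If it helps, the same sweep can be reproduced without nvtop by asking PyTorch's allocator directly. This is a hypothetical check, not my actual script; make_loader is a stand-in for however the DataLoader gets built for a given batch size:

import torch

for batch_size in range(1, 17):
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    inputs, labels = next(iter(make_loader(batch_size)))   # make_loader is a stand-in helper
    inputs, labels = inputs.to(device), labels.to(device)

    model.zero_grad()
    loss = criterion(labels, model(inputs))                # same call signature as in the training loop below
    loss.backward()

    print('batch_size {}: peak {:.1f} MiB'.format(batch_size, torch.cuda.max_memory_allocated() / 1024**2))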

As of now, I am kind of stumped and don’t know what the underlying cause of my slow training and strange memory usage patterns could be. Any help would be greatly appreciated.

Thanks!

This is the training loop:

import copy
import os

import torch


def train_model(model, criterion, optimizer, scheduler, num_epochs=100):
    epoch_loss = []
    epoch_acc = []
    # kept for tracking the best model (not updated below)
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):

        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'test']:

            if phase == 'train':
                model.train()   # set model to training mode
                running_loss = 0.0

            elif phase == 'test':
                model.eval()    # set model to evaluate mode
                running_test_loss = 0.0

            for inputs, labels in dataloaders[phase]:
                # move the batch to the GPU
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                if phase == 'train':
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        loss = criterion(labels, outputs)
                        # backward + optimize only in the training phase
                        loss.backward()
                        optimizer.step()
                        # statistics
                        running_loss += loss.item() * inputs.size(0)

                if phase == 'test':
                    with torch.no_grad():
                        outputs = model(inputs)
                        test_loss = criterion(labels, outputs)
                        running_test_loss += test_loss.item() * inputs.size(0)

            if phase == 'train':
                scheduler.step()

        # per-epoch statistics (epoch_acc currently tracks the test loss)
        epoch_loss.append(running_loss / dataset_sizes['train'])
        epoch_acc.append(running_test_loss / dataset_sizes['test'])

        print('Train Loss: {:.4f}\nTest Loss {:.4f}'.format(epoch_loss[epoch], epoch_acc[epoch]))

        # save a checkpoint every 10 epochs
        if epoch % 10 == 9:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                'acc': epoch_acc
            }, os.getcwd() + '/save_model/{}__.pt'.format(epoch))

    return model, epoch_loss, epoch_acc