Training takes too long

I am trying to use transfer learning to train an NFNet model (dm_nfnet_f6 from timm) for steganography detection as part of an assignment, but training is taking too long: one epoch currently takes about an hour.

Here is the model definition:

import timm
import torch.nn as nn

def build_model(model_name):
    # create the pretrained backbone
    model = timm.create_model(model_name, pretrained=True)

    # freeze the pretrained layers (for transfer learning)
    for param in model.parameters():
        param.requires_grad = False

    # get the number of input features of the existing classifier head
    num_features = model.head.fc.in_features

    # replace the classifier head with a new trainable layer (4 output classes)
    last_layer = nn.Linear(num_features, 4)
    model.head.fc = last_layer

    return model
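
The optimizer, scheduler and criterion are created in the driver code (not shown). Since everything except model.head.fc is frozen, only the head parameters need to be passed to the optimizer; roughly along these lines (the optimizer type, momentum and gamma are placeholders, the lr and step_size values come from the driver arguments further down):

import torch

model = build_model("dm_nfnet_f6")

# only the new head still requires gradients, so only those parameters go to the optimizer
trainable_params = [p for p in model.parameters() if p.requires_grad]

# placeholder optimizer, scheduler and loss (SGD momentum and StepLR gamma are illustrative)
optimizer = torch.optim.SGD(trainable_params, lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
criterion = torch.nn.CrossEntropyLoss()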

The training loop is:

# get dataloaders
    _,dataloaders, datasets_size,_ = load_datasets(model_name,batch_size,num_workers)
    
    # to plot losses
    train_losses = []
    val_losses = []

    # timer
    since = time.time()
    
    # best accuracy
    best_acc = 0.0

    # best model weights
    best_model_wts = copy.deepcopy(model.state_dict())
    
    # training loop
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        with open(log_file_path,"a",newline='') as log_file:
            log_file.write('Epoch {}/{} \n'.format(epoch, num_epochs - 1))
            log_file.write("---------- \n")
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

        # per-epoch loss trackers (initialized once per epoch so the
        # training loss is still available when the val phase logs it)
        current_training_loss = 0.0
        current_val_loss = 0.0

        # train and validation phase
        for phase in ['train','val']:
            
            # set model mode
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            # running loss and correct-prediction accumulators for this phase
            current_loss = 0.0
            current_corrects = 0
            
            print("Iterating through data")
            
            # iterate over data
            for inputs, labels in dataloaders[phase]:
                # get inputs
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # reset grad
                optimizer.zero_grad()
                
                # perform prediction
                with torch.set_grad_enabled(phase == 'train'):
                    # prediction
                    outputs = model(inputs)
                    
                    # predicted class indices
                    _, preds = torch.max(outputs,1)
                    
                    # loss
                    loss = criterion(outputs,labels)
                    
                    if phase == 'train':
                        # back propagation
                        loss.backward()
                        optimizer.step()
                    
                # compute performance metrics
                current_loss += loss.item() * inputs.size(0)
                current_corrects += torch.sum(preds == labels)
            
            if phase == 'train':
                scheduler.step()
            
            # calculate epoch metrics
            epoch_loss = current_loss / datasets_size[phase]
            epoch_acc = current_corrects.double() / datasets_size[phase]
            
            # store losses values
            if phase == 'train':
                current_training_loss = epoch_loss
                train_losses.append(epoch_loss)
            else:
                current_val_loss = epoch_loss
                val_losses.append(epoch_loss)
            
            if phase == 'val':
                with open(all_epoch_metrics,"a",newline='') as temp_file:
                    write_data = f"{epoch},{epoch_acc},{current_val_loss},{current_training_loss}\n"
                    temp_file.write(write_data)
            
            epoch_end_time = time.time()
            epoch_time_interval = epoch_end_time - epoch_start_time
            with open(log_file_path,"a",newline='') as log_file:
                log_file.write("{} \n".format(phase))
                log_file.write("start time: {:} seconds \n".format(epoch_start_time))
                log_file.write("end time: {:} seconds \n".format(epoch_end_time))
                log_file.write("interval: {:.0f} minutes {:.0f} seconds \n".format(epoch_time_interval // 60, epoch_time_interval % 60))
                log_file.write("{} Loss: {:.4f} Acc: {:.4f} \n".format(phase,epoch_loss,epoch_acc))
                
                print("Start time: {} seconds".format(epoch_start_time))
                print("End Time: {} seconds".format(epoch_end_time))
                print("interval: {:.0f} min {:.0f} sec".format(epoch_time_interval // 60, epoch_time_interval % 60))
                print("{} Loss: {:.4f} Acc: {:.4f}".format(phase,epoch_loss,epoch_acc))

            # check if validation performance improved
            if phase == 'val' and epoch_acc > best_acc:
                # set new best accuracy
                best_acc = epoch_acc
                
                # save model weights
                best_model_wts = copy.deepcopy(model.state_dict())
                
                # save model values
                current_path = f"./best_model_{epoch}.pth"
                
                with open(best_model_metrics_path,"a",newline='') as best_model_metrics_file:
                    data_to_write = f"{current_path},{best_acc}\n"
                    best_model_metrics_file.write(data_to_write)

                # save model
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'current_loss': current_loss,
                    'current_corrects': current_corrects,
                    'loss': criterion
                },current_path)
            
            # check if time to save model
            if epoch % checkpoint_epoch_num == 0 and phase == 'val':
                # save model
                current_path = f"./checkpoint_model_{epoch}.pth"
                
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'current_loss': current_loss,
                    'current_corrects': current_corrects,
                    'loss': criterion
                },current_path)
                
                
        print()

    time_since = time.time() - since
    with open(log_file_path,"a",newline='') as log_file:
        log_file.write("Training complete in {:.0f}m {:.0f}s \n".format(time_since // 60, time_since %60))
        log_file.write("Best val acc: {:.4f} \n".format(best_acc))
    
    print("Training complete in {:.0f}m {:.0f}s".format(time_since // 60, time_since %60))
    print("Best val acc: {:.4f}".format(best_acc))

I am training on Kaggle’s free GPU with the following parameters:

# call driver program above
main(
    model_name="dm_nfnet_f6",
    batch_size=32, # 32,64,128,256,512,1024
    num_workers=2, #constant
    learning_rate=1e-3, # 1e-2,1e-3,1e-4,1e-5,1e-6,1e-7
    plot_loss=True,
    num_epochs=1, # 20 for testing
    step_size=5,
    checkpoint_epoch_num=1 # due to low number of epochs executed
)

My training set contains 29,432 images and the test set contains 9,800.
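With a batch size of 32 that is roughly ceil(29432 / 32) = 920 training batches plus ceil(9800 / 32) = 307 validation batches per epoch, so an hour per epoch corresponds to roughly 3 to 4 seconds per batch.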

Is there any way to speed up training? Are there any issues with the code that might be slowing it down?