PyTorch not Using Full GPU

I am trying to optimize this script. It runs fine; it’s just too slow. Some specs: I have a GPU with 11 GB of memory on a server I don’t maintain but have some permissions on. I also have more than enough CPU RAM (1.7 TB) for the files I’m processing.

I have looked through the forum for fixes and added some, but they didn’t seem to help much. The biggest issue is low GPU memory usage (~7.5 of 11 GB). I have an unreasonably huge batch size of 65536, which was running OK yesterday (~18 mins/epoch), but is now taking well over an hour per epoch with the same parameters, and I’m not sure what change I made that could have caused this.

I am using a different file than before, but it’s only a 4:3 difference in size, and the content is essentially the same as the previous file, just more of it. At that ratio, 18 mins/epoch should scale to roughly 24 mins/epoch, not over an hour.

Basically, I’m hoping to get more pairs of eyes on my code to see if anyone has suggestions about how to speed this up.

import os
from collections import OrderedDict
from copy import deepcopy
from datetime import datetime

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizerFast, BertForPreTraining, BertConfig

# paths
proj_dir = '/scratch/ddegenaro'
def in_proj_dir(dir):
    return os.path.join(proj_dir, dir)
pretraining_test = in_proj_dir('pretraining_test.txt')
pretraining_txt = in_proj_dir('pretraining.txt')
inits = in_proj_dir('inits')
ckpts = in_proj_dir('ckpts')
trained = in_proj_dir('trained')

print('Getting tokenizer.')
# get tokenizer and initialize teacher model mBERT
tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased", do_lower_case=False)
print('Done.')
print('Getting mBERT.')
# this line will complain that decoder bias was not in the checkpoint
mBERT = BertForPreTraining.from_pretrained("bert-base-multilingual-cased")
print('Done.')

teacher = mBERT # first network to copy from
MSELoss = torch.nn.MSELoss() # loss between logits of two models
batch_size = 65536 # batch size
epochs = 32 # num epochs

class BertData(Dataset):
    def __init__(self):
        print('Reading in corpus. Warning: requires ~ 50 GB of RAM.')
        self.corpus = open(pretraining_txt).readlines()
        print('Done.')
    def __len__(self):
        return len(self.corpus)
    def __getitem__(self, idx):
        return tokenizer(self.corpus[idx], return_tensors='pt', padding='max_length', truncation=True, max_length=512)

dataset = BertData()

data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=12, pin_memory=True)

for i in reversed(range(2,12)): # TA builder loop

  teacher_state_dict = teacher.state_dict()

  # create a BertConfig with a multilingual vocabulary for the next TA
  config_obj = BertConfig(vocab_size=119547, num_hidden_layers=i)

  student = BertForPreTraining(config_obj) # initialize next model and state dict
  student_state_dict = OrderedDict()

  torch.cuda.empty_cache()

  teacher.to('cuda') # use GPU
  student.to('cuda')

  print('Building student.')
  for key in teacher_state_dict: # copy architecture and weights besides top layer
    if str(i) not in key:
      student_state_dict[key] = deepcopy(teacher_state_dict[key])
  print('Done.')

  # save init for this TA
  print('Saving student.')
  torch.save(student_state_dict, os.path.join(inits, 'ta' + str(i)))
  print('Done.')

  # load next state dict into the next model
  student.load_state_dict(student_state_dict)

  student.train() # ensure training mode

  # generate Adam optimizer close to mBERT's
  optimizer = torch.optim.Adam(student.parameters(), lr=(batch_size/256*1e-4),
                             betas=(0.9, 0.999), eps=1e-06, weight_decay=0)

  optimizer.zero_grad(set_to_none=True) # just to be sure

  with torch.set_grad_enabled(True):

    for k in range(epochs):

      start = datetime.now()

      print(f'Begin epoch {k+1}/{epochs}. Current time: {datetime.now()}.')

      loss = 0 # initialize

      for batch_idx, inputs in enumerate(data_loader):

        for j in inputs:
          inputs[j] = inputs[j][0]
        inputs = inputs.to('cuda')

        # get teacher and student predictions
        teacher_logits = teacher(**inputs).prediction_logits
        student_logits = student(**inputs).prediction_logits
        
        # calculate the loss between them and update
        loss = MSELoss(teacher_logits, student_logits) / batch_size
      
        # learning step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        loss = 0
        print(batch_idx+1, (datetime.now()-start)/(batch_idx+1))
    
      torch.save(student.state_dict(), os.path.join(ckpts, 'ta' + str(i) + '_ckpt' + str(k)))

  # save trained model for this TA
  torch.save(student.state_dict(), os.path.join(trained, 'ta' + str(i)))

  teacher = student # prepare to initialize next network

# end for

I should also add that I tried torch.utils.bottleneck, but it spammed my console continuously until I killed the process. I recently posted a GitHub issue about that.

I don’t know why you are using torch.cuda.empty_cache(), but it will slow down your code and will not avoid any out-of-memory issues (it will allow other applications to use GPU memory in case that’s your use case).

torch.utils.bottleneck is a debugging utility which might be helpful to narrow down bottlenecks. However, profiling the code with the PyTorch profiler or e.g. Nsight Systems might give you more information about which part of the code is slow.
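E.g. a minimal (untested) sketch with torch.profiler, reusing the data_loader, teacher, and student objects from your script, could look like this and would show whether the time goes into fetching the batch or into the forward passes:

import torch
from torch.profiler import profile, record_function, ProfilerActivity

# profile a few iterations of the existing loop (no backward pass, first look only)
it = iter(data_loader)
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for step in range(3):  # a few batches are enough for a first look
        with record_function("dataloader_next"):
            inputs = next(it)
        with record_function("to_gpu"):
            for j in inputs:
                inputs[j] = inputs[j][0]
            inputs = inputs.to('cuda')
        with record_function("forward"):
            teacher_logits = teacher(**inputs).prediction_logits
            student_logits = student(**inputs).prediction_logits

# the labeled regions show up in the table; sort by CPU time if the CUDA columns look empty
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))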

So it looks like the actual issue is the data loader. I don’t know why it’s so slow, but I tried running the profiler and all the operations inside the actual loop run in milliseconds or microseconds. So it would seem my issue is somewhere in here:

class BertData(Dataset):
    def __init__(self):
        print('Reading in corpus. Warning: requires ~ 50 GB of RAM.')
        self.corpus = open(pretraining_txt).readlines()
        print('Done.')
    def __len__(self):
        return len(self.corpus)
    def __getitem__(self, idx):
        return tokenizer(self.corpus[idx], return_tensors='pt', padding='max_length', truncation=True, max_length=512)

dataset = BertData()

data_loader = DataLoader(dataset, batch_size=65536, num_workers=12, pin_memory=True)

I don’t quite know what the problem would be; if nothing looks amiss here, I guess I’ll look into HuggingFace’s stuff.

You could iterate the DataLoader alone (without any model training) and see how long it would take to create the next batch.
In your current implementation each worker will load 65536 samples and will thus call into __getitem__ and the tokenizer 65536 times. I don’t know how expensive the tokenizer calls are, but you should consider the overhead of all 12 workers each calling into it 65k times for a single batch.
Would it be possible to call the tokenizer once on the dataset or will it blow up the memory?
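E.g. something as simple as this (untested) would show the per-batch loading time on its own:

from datetime import datetime

# iterate the DataLoader without any model work to isolate the loading cost
start = datetime.now()
for batch_idx, inputs in enumerate(data_loader):
    print(batch_idx + 1, (datetime.now() - start) / (batch_idx + 1))
    if batch_idx == 4:  # a handful of batches is enough to see the trend
        break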

So there is a “fast” tokenizer I can use instead which allows batching the inputs to the tokenizer, but it runs out of GPU memory with a batch size greater than 1.

I guess I could try calling the tokenizer on everything at the beginning, but are you suggesting storing all of its outputs in memory? It may be possible, but I just want to be sure we’re on the same page.

Yes, storing the processed data in RAM might work, but I don’t know how large it would be, how expensive each call into the tokenizer is, whether it would be possible to use the CPU for the fast tokenizer (and how fast that would be), etc.
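For a very rough estimate: if the IDs were stored as int32 and padded to the max_length=512 from your Dataset, each line would cost about 512 × 4 ≈ 2 KB (plus the attention mask), so the total would be roughly 2 KB times the number of lines in the corpus.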
I would still recommend narrowing down the actual bottleneck and making sure you are avoiding it.

The tokenizer is not fast enough.
If you look around at training examples, people will normally pre-tokenize the whole dataset and save the integers into a numpy file (or similar), which is normally smaller than the txt file.
Then the torch.Dataset doesn’t actually have to do any processing apart from fetching the data from memory.
This will also save you some electricity from the CPU processing :slight_smile:

(I’d also suggest using npz+tar files with WebDataset if they don’t fit in memory)
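To make it concrete, a rough and untested sketch could look like the following (the .npy filenames, the PreTokenizedData class name, and the chunk size are just placeholders; pretraining_txt and the tokenizer are the ones from the script above):

import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased", do_lower_case=False)

# 1) tokenize the whole corpus once, in chunks, on the CPU, and save the integers
lines = open(pretraining_txt).readlines()
ids = np.empty((len(lines), 512), dtype=np.int32)
mask = np.empty((len(lines), 512), dtype=np.int8)

chunk = 10000
for start in range(0, len(lines), chunk):
    enc = tokenizer(lines[start:start + chunk], padding='max_length',
                    truncation=True, max_length=512, return_tensors='np')
    ids[start:start + chunk] = enc['input_ids']
    mask[start:start + chunk] = enc['attention_mask']

np.save('pretraining_ids.npy', ids)
np.save('pretraining_mask.npy', mask)

# 2) the Dataset then only fetches pre-computed integers, no tokenization
#    (token_type_ids are omitted; the model defaults them to zeros)
class PreTokenizedData(Dataset):
    def __init__(self):
        # mmap_mode avoids reading everything into RAM at once
        self.ids = np.load('pretraining_ids.npy', mmap_mode='r')
        self.mask = np.load('pretraining_mask.npy', mmap_mode='r')
    def __len__(self):
        return len(self.ids)
    def __getitem__(self, idx):
        return {'input_ids': torch.tensor(self.ids[idx], dtype=torch.long),
                'attention_mask': torch.tensor(self.mask[idx], dtype=torch.long)}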

This is the direction I’m trying to go, but unfortunately I’m not too familiar with PyTorch yet. Is there any chance you could direct me to an example of some code to pretokenize? Thank you so much!

Came up with a method that seems to work. Thank you both for this suggestion!


Not off the top of my head, but there is a lot around on Kaggle, e.g. the first cells in jigsaw20 ds tt6.36 | Kaggle (though I am not sure if this was using multiple threads…)