Using TPUs appears to get stuck at the second step in training (Transformers model)

Good day to all of you! I am pretty new to parallel training and wish to train my model on distributed TPUs. (I am not sure if this is the right place to ask, so please redirect me if I am wrong.)

My code is basically from a standard tutorial with slight changes to use a custom dataset. The code works well on a single GPU, say in Colab. However, when using TPUs it gets through the first step of the training loop but deadlocks when getting the outputs from the model in the second step.

At first I thought it was the data sampler, since my dataset is imbalanced and I had been using DistributedSamplerWrapper from Catalyst. However, switching to PyTorch's DistributedSampler does not make any difference.
I also thought the batch size might be too large, so I tried different settings from 64 down to 8, but that did not work either…

The data loader part:

## Dataloader ##

class TweetsData(Dataset):
  def __init__(self, dataframe, tokenizer, max_len):
      self.tokenizer = tokenizer
      self.data = dataframe
      self.sentence = dataframe.sentence
      self.targets = self.data.label
      self.max_len = max_len

  def __len__(self):
      return len(self.sentence)

  def __getitem__(self, index):
      sentence = str(self.sentence[index])
      sentence = " ".join(sentence.split())
      inputs = self.tokenizer.encode_plus(
          sentence,
          # Pad/truncate to max_length so that tensors can be stacked within each batch
          padding="max_length",
          truncation=True,
          max_length=self.max_len,
          # token_type_ids are read below, so request them explicitly
          return_token_type_ids=True
      )
      ids = inputs['input_ids']
      mask = inputs['attention_mask']
      token_type_ids = inputs["token_type_ids"]

      return {
          'ids': torch.tensor(ids, dtype=torch.long),
          'mask': torch.tensor(mask, dtype=torch.long),
          'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
          'targets': torch.tensor(self.targets[index], dtype=torch.float)
      }

Results are averaged across the 8 TPU cores using this function:

## Define how loss is averaged out of the 8 TPU cores
def reduce_fn(vals):
  # take average
  return sum(vals) / len(vals)
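
This is used to average per-core values across all 8 cores via xm.mesh_reduce. For example, to print the cross-core average of the loss instead of each core's own value (as the comment in the training loop below suggests), it would look something like this (illustrative only, not in my current loop):

# Average the per-core loss across all 8 cores before printing (illustrative)
avg_loss = xm.mesh_reduce("loss_reduce", loss.item(), reduce_fn)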

The training loop (I print at every step to see which part gets stuck):

# Define training loop function

def train_loop_fn(data_loader, model, optimizer, device, scheduler = None):
  tracker = xm.RateTracker()
  model.train() # Put model to training mode
  for bi, data in enumerate(data_loader):
    print("Start")
    start_time = time.time()
    print("Extract data")
    # Extract data
    ids = data['ids'].to(device, dtype = torch.long)
    mask = data['mask'].to(device, dtype = torch.long)
    token_type_ids = data['token_type_ids'].to(device, dtype = torch.long)
    targets = data['targets'].to(device, dtype = torch.long)

    # Reset the gradients
    print("Zero Grad")
    optimizer.zero_grad()

    # Pass ids, mask, token_type_ids to model 
    print("Model")
    outputs = model(ids, mask, token_type_ids)

    # Create the loss function (cross entropy loss for multi-class classification)
    print("Loss")
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(outputs, targets)   

    # Backprop
    print("Backward")
    loss.backward()
    
    # Use PyTorch XLA optimizer stepping
    print("Step Optimizer")
    xm.optimizer_step(optimizer)

    # Print every 20 steps
    # if bi%20==0:
      # since the loss is on all 8 cores, reduce the loss values and print the average (as defined in reduce_fn)
    print('[xla:{}]({}) Loss={:.5f} Rate={:.2f} GlobalRate={:.2f} Time={}'.format(
      xm.get_ordinal(), bi, loss.item(), tracker.rate(),
      tracker.global_rate(), time.asctime()), flush=True)

    if scheduler is not None:
      scheduler.step()
    end_time = time.time()
    print(f"Time for steps {bi}: {end_time - start_time}")

  # Set model to evaluation mode
  model.eval()
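
For reference, eval_loop_fn (called from map_fn below) is not posted in full; a simplified sketch of what it does, i.e. collect argmax predictions and targets from each batch, would look roughly like this:

def eval_loop_fn(data_loader, model, device):
  model.eval()
  model_label, target_label = [], []
  with torch.no_grad():
    for bi, data in enumerate(data_loader):
      ids = data['ids'].to(device, dtype=torch.long)
      mask = data['mask'].to(device, dtype=torch.long)
      token_type_ids = data['token_type_ids'].to(device, dtype=torch.long)
      targets = data['targets'].to(device, dtype=torch.long)

      # Forward pass; take the highest-scoring class as the prediction
      outputs = model(ids, mask, token_type_ids)
      preds = torch.argmax(outputs, dim=1)

      model_label.extend(preds.cpu().numpy().tolist())
      target_label.extend(targets.cpu().numpy().tolist())
  return model_label, target_label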

The map_fn function (I only post the parts that are related to the problem):

## https://www.kaggle.com/tanlikesmath/the-ultimate-pytorch-tpu-tutorial-jigsaw-xlm-r

def map_fn(index, flags):
  torch.set_default_tensor_type('torch.FloatTensor')

  # Sets a common random seed - both for initialization and ensuring graph is the same
  torch.manual_seed(TORCH_SEED)

  # Acquires the (unique) Cloud TPU core corresponding to this process's index
  device = xm.xla_device()  

  # Use only one process to download/read the datasets:
  # non-master processes wait at the rendezvous until the master has finished
  if not xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  train_dataset = pd.read_csv(root_path + "train_set.csv")
  val_dataset = pd.read_csv(root_path + "dev_set.csv")

  # The master joins the rendezvous afterwards, releasing the other processes
  if xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  tokenizer = tfm.AutoTokenizer.from_pretrained(root_path + "BERTweet_uncased", use_fast=False, return_tensors='pt')

  # Custom dataloader __init__, __len__,  __getitem__ #
  train_set = TweetsData(train_dataset, tokenizer, MAX_LEN)
  val_set = TweetsData(val_dataset, tokenizer, MAX_LEN)

  # Training dataset loader #
  # Wrap our Class imbalance Sampler with DistributedSamplerWrapper

  # train_sampler = DistributedSamplerWrapper(
  #     sampler=BalanceClassSampler(labels=train_dataset.label.values, mode='upsampling'),
  #     num_replicas=xm.xrt_world_size(),
  #     rank=xm.get_ordinal(),
  #     shuffle=True
  # )

  train_sampler = DistributedSampler(
      dataset=train_set,  # wrap the same Dataset that the DataLoader uses
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True
  )

  train_loader = DataLoader(train_set,
      batch_size=TRAIN_BATCH_SIZE,
      sampler=train_sampler,
      num_workers=NUM_WORKERS_DATA,
      drop_last=True)

  # Validation dataset loader #
  # val_sampler = DistributedSamplerWrapper(
  #   sampler=BalanceClassSampler(labels=val_dataset.label.values , mode='upsampling'),
  #   num_replicas=xm.xrt_world_size(),
  #   rank=xm.get_ordinal(),
  #   shuffle=True
  # )

  val_sampler = DistributedSampler(
      dataset=val_set,  # wrap the same Dataset that the DataLoader uses
      num_replicas=xm.xrt_world_size(),
      rank=xm.get_ordinal(),
      shuffle=True
  )

  val_loader = DataLoader(val_set,
      batch_size=VALID_BATCH_SIZE,
      sampler=val_sampler,
      num_workers=NUM_WORKERS_DATA,
      drop_last=True)

  

  # Push our neural network to TPU 
  model = bertweetClass()
  model.to(device)

  # Don't apply weight decay to bias and LayerNorm parameters
  param_optimizer = list(model.named_parameters()) # model parameters to optimize
  no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

  optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

  # Create the optimizer (AdamW), scaling the learning rate by the number of TPU cores
  optimizer = AdamW(params = optimizer_grouped_parameters , lr = LEARNING_RATE * xm.xrt_world_size())

  # Create number of training steps
  num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS) 

  # Scheduler for optimizer for learning decay
  scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_train_steps
  )

  xm.master_print(f"Train for {len(train_dataset)} steps per epoch")
  xm.master_print(f'num_training_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')

  for epoch in range(EPOCHS):
    gc.collect()
    xm.master_print(f"Starting training in epoch: {epoch}")

    ## Training Part ##
    xm.master_print("Entering training loop")
    para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
    gc.collect()
    # Call Training Loop
    train_loop_fn(para_train_loader, model, optimizer, device, scheduler=scheduler)
    del para_train_loader
    gc.collect()

    ## Evaluation Part ##
    para_eval_loader = pl.ParallelLoader(val_loader, [device]).per_device_loader(device)
    xm.master_print("Entering validation loop")
    # Call Evaluation Loop
    model_label, target_label = eval_loop_fn(para_eval_loader, model, device)
    del para_eval_loader
    gc.collect()

    ## Evaluation metrics ##
    ## Reporting Matthews correlation coefficient ##
    epoch_mcc = matthews_corrcoef(target_label, model_label, sample_weight=None)
    epoch_mcc = xm.mesh_reduce("mcc", epoch_mcc, reduce_fn)
    xm.master_print(f"Matthews correlation coefficient at epoch {epoch}: {epoch_mcc}")
    # Macro-averaged F1, since this is a multi-class (3-label) problem
    epoch_f1 = f1_score(target_label, model_label, average="macro", sample_weight=None)
    epoch_f1 = xm.mesh_reduce("f1", epoch_f1, reduce_fn)
    xm.master_print(f"F1 score at epoch {epoch}: {epoch_f1}")

Lastly, spawn the processes with these parameters:

## Define key variables to be used in training 
NUM_LABELS = 3
MAX_LEN = 128
TRAIN_BATCH_SIZE = 32
VALID_BATCH_SIZE = 32
EPOCHS = 1
LEARNING_RATE = 3e-05
NUM_WORKERS_DATA = 2
TORCH_SEED = 1234

flags = {}
xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

Here is the interpreter output from running xmp.spawn:

Train for 12638343 steps per epoch
num_training_steps = 789896, world_size=8
Starting training in epoch: 0
Entering training loop
Start
Extract data
Zero Grad
Model
Loss
Backward
Step Optimizer
xla:0 Loss=1.03125 Rate=0.00 GlobalRate=0.00 Time=Fri May 7 12:56:08 2021
Time for steps 0: 8.53129506111145
Start
Extract data
Zero Grad
Model

It gets stuck at getting the output from the model in the second step, seemingly forever…

Could it be that I am not using the nightly release of the PyTorch XLA package?

I encountered the same problem as in this thread:
https://stackoverflow.com/questions/67257008/oserror-libmkl-intel-lp64-so-1-cannot-open-shared-object-file-no-such-file-or

and a bug report here:

https://github.com/pytorch/xla/issues/2933

Currently I am using:

!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8.1-cp37-cp37m-linux_x86_64.whl

Sorry for all the long code, but help is much appreciated!

Thanks all

Hi, are you initializing process groups (using init_process_group) and using DDP somewhere in your code? DistributedSampler is not intended for use outside of a distributed/DDP setting. Does the training loop stop hanging when you remove the distributed components?

cc @ailzhang for XLA/TPU question

Hi Rohan, many thanks for your reply! I did not call init_process_group; apparently that is only for parallelization on CPU/GPU? In fact, I have not seen it in any XLA tutorial, such as:

https://www.kaggle.com/tanlikesmath/the-ultimate-pytorch-tpu-tutorial-jigsaw-xlm-r

Regarding the sampler: in a non-distributed setting I do not use DDP but a normal sampler such as WeightedRandomSampler, and that works flawlessly, along the lines sketched below.
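
For example, something like this in the single-device case (a rough sketch, assuming the labels are the integers 0, 1, 2):

from torch.utils.data import DataLoader, WeightedRandomSampler

# Weight each sample inversely to its class frequency (assumes 0-based integer labels)
class_counts = train_dataset.label.value_counts().sort_index().values
sample_weights = 1.0 / class_counts[train_dataset.label.values]

train_sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True
)
train_loader = DataLoader(train_set, batch_size=TRAIN_BATCH_SIZE, sampler=train_sampler)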

Thanks!

I had exactly the same problem. The training loop just got stuck after the first step and stopped loading new data. I could see that memory kept increasing until it leaked. I am using pytorch-xla 1.9.

@gabrielwong1991 Have you resolved this problem?

Hi bryan, I can’t exactly remember what I did, but you might want to check how you define your dataset class.

For me, the task was training a Hugging Face BERT model, and instead of defining a dataset class I just used their datasets library, roughly as sketched below.
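
Something like this (a rough sketch from memory, not my exact code; file paths and the column name are placeholders, and tokenizer is assumed to be loaded already):

from datasets import load_dataset

# Read the CSVs with the datasets library instead of a custom PyTorch Dataset
raw = load_dataset("csv", data_files={"train": "train_set.csv", "validation": "dev_set.csv"})

def tokenize(batch):
    return tokenizer(batch["sentence"], padding="max_length", truncation=True, max_length=128)

tokenized = raw.map(tokenize, batched=True)
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])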

Thank you for your reply. I just found my problem: it was in the loss backward part rather than the dataset loading. It may be because of my model itself. I am still investigating it and have just opened a new thread.