PyTorch transformer model crashes after around 30 epochs

I am training a transformer model in SageMaker, and the notebook keeps crashing without any error message. It just goes to idle and hangs. My code is below.

Not getting any error message makes it a little hard to debug, but I’m thinking maybe I’m running out of memory on my GPU? I’m not sure how this would happen, since I’m able to run around 30 epochs before the instance restarts and goes to idle. Is there anything in my code that may cause the instance to crash?

```
import torch
from sklearn.utils import shuffle  # joint shuffle of input_ids, labels, mask
from transformers import ElectraForSequenceClassification, get_linear_schedule_with_warmup


def elec(input_ids, labels, mask, validation_inputs, validation_labels, validation_mask, epochs=50):

    loss_list = []
    lowest_loss = 1000000
    lowest_loss_model = None

    model = ElectraForSequenceClassification.from_pretrained(
        'google/electra-base-discriminator',
        num_labels=2,
        output_hidden_states=True,
        output_attentions=True)

    model.train()

    # note: these grouped parameters are defined but never passed to the optimizer below
    optimizer_grouped_parameters = [
        {'params': [i for i in model.electra.embeddings.parameters()], 'lr': 1, 'momentum': 1},
        {'params': [i for i in model.electra.encoder.parameters()], 'lr': 1, 'momentum': 1},
        {'params': [i for i in model.classifier.parameters()], 'lr': 1, 'momentum': 1}
    ]

    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=5e-5,  # args.learning_rate - default is 5e-5, our notebook had 2e-5
                                  eps=1e-8  # args.adam_epsilon - default is 1e-8
                                  )

    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,  # default value in run_glue.py
                                                num_training_steps=epochs)

    for batch in range(epochs):
        input_ids, labels, mask = shuffle(input_ids, labels, mask)

        input_ids_subset = input_ids[:30]
        labels_subset = labels[:30]
        mask_subset = mask[:30]

        model.train()
        optimizer.zero_grad()

        # forward pass
        outputs = model(input_ids_subset,
                        token_type_ids=None,
                        attention_mask=mask_subset,
                        labels=labels_subset)

        # loss = loss_fn(outputs, labels)

        # this loss tensor is attached to the entire model's computation graph
        loss = outputs[0]
        # compute gradients; model.parameters() should now hold grad data
        loss.backward()

        optimizer.step()
        scheduler.step()

        model.eval()

        validation_output = model(validation_inputs,
                                  token_type_ids=None,
                                  attention_mask=validation_mask,
                                  labels=validation_labels)

        validation_loss = validation_output[0]

        if validation_loss < lowest_loss:
            lowest_loss = validation_loss
            # del(lowest_loss_model)
            # lowest_loss_model = copy.deepcopy(model)
        loss_list.append(validation_loss)

        print(batch)
        print(validation_loss)

    return lowest_loss_model, lowest_loss, loss_list  # , model
```

Could you rerun your script in a terminal? It might show a proper error message. E.g., if you are running out of host memory, the OS will kill the process to avoid crashing the machine, and the notebook might hide that and just restart.
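If it turns out to be GPU rather than host memory, a rough way to confirm a leak is to print the allocated and reserved GPU memory once per epoch. A minimal sketch, assuming a single CUDA device; these are standard torch.cuda calls, not something taken from your code:

```
import torch

# call once per epoch inside the training loop;
# a steadily growing "allocated" value suggests something is being retained
print(f"allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB | "
      f"reserved: {torch.cuda.memory_reserved() / 1024**2:.1f} MiB")
```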

Yeah, I’ll try that out. Do you know why it could be running out of memory in the first place? I could understand the memory blowing up and the instance crashing while the gradient is computed for the first time. But if it’s able to hold all of that in memory once, why would it not be able to on the 30th epoch? I don’t see how the memory could be increasing from epoch to epoch.

Usually an increase in memory usage is caused by storing a tensor that is still attached to the entire computation graph, which prevents PyTorch from freeing the intermediate activations.
This might also be the case in your code, since you are not wrapping the forward pass of the validation run in a torch.no_grad() guard and are then appending validation_loss directly to a list without detaching it:

loss_list.append(validation_loss)

which will store the entire computation graph from this forward pass.
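For reference, a minimal sketch of that change, reusing the variable names from the posted training loop and assuming the rest of the loop stays as written: run the validation forward pass under torch.no_grad() and store only the plain Python float via .item().

```
model.eval()

# no graph is built here, so intermediate activations are freed immediately
with torch.no_grad():
    validation_output = model(validation_inputs,
                              token_type_ids=None,
                              attention_mask=validation_mask,
                              labels=validation_labels)
    validation_loss = validation_output[0]

# .item() converts the 0-dim loss tensor to a Python float,
# so nothing attached to a computation graph is kept in the list
if validation_loss.item() < lowest_loss:
    lowest_loss = validation_loss.item()

loss_list.append(validation_loss.item())
```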

That seemed to be the issue. Thanks so much for your help!