CUDA out of memory after performing loss.backward()?

I’m getting a

RuntimeError: CUDA out of memory. Tried to allocate 40.00 MiB (GPU 0; 14.76 GiB total capacity; 12.66 GiB already allocated; 35.75 MiB free; 13.29 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.

after calling loss.backward() in the code below:

def train(epochs = epoch, batch_count = amount_per_batch):
    ######
    db = get_data()
    
    test_losses = []
    train_losses = []
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    lowest_loss = 1000000000
    
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    nli_model = get_models('s3://chatbots-for-auto-labeling/t5/encoder_plus_logit_layer_plus_softmax/03_28_2023__03_56_46model.tar.gz').to(device)
    
    optimizer = AdamW(nli_model.parameters(),
                          lr = learning_rate, # previous 8e-6
                          eps = 1e-8 # args.adam_epsilon  - default is 1e-8.
                        )
    scheduler = get_linear_schedule_with_warmup(optimizer, 
                                                num_warmup_steps = 0, # Default value in run_glue.py
                                                num_training_steps = 1)
    
    
    
    down_sampled_db = shuffle(db)
    
    count_epochs = 0
    
    sentences = down_sampled_db['description'].to_list()
    questions = down_sampled_db['keyword'].to_list()
    yes_or_no = down_sampled_db['yes_or_no']
    
    toks = tokenizer(sentences, questions, padding='longest') 
    
    ds = Dataset.from_dict({"x": torch.tensor(toks['input_ids']), "mask": torch.tensor(toks['attention_mask']), 'labels' : torch.tensor([ 0 if i == 'no' else 1 for i in yes_or_no])}).with_format("torch")
    
    dataloader = DataLoader(ds, batch_size=batch_count)
    
    while count_epochs != epochs:
        for batch in dataloader:
           
            nli_model.train()
            x_batch = batch["x"].to(device)
            mask_batch = batch["mask"].to(device)
            labels_batch = batch["labels"].to(device)
            loss = nli_model(x_batch, attention_mask=mask_batch, labels = labels_batch)[0]
            print(loss)
            loss.backward()
            print('completed backwards pass')

            optimizer.step()
            optimizer.zero_grad()

        # reshuffle and re-tokenize for the next epoch, then rebuild the dataloader
        # (otherwise the reshuffled data is never used)
        down_sampled_db = shuffle(db)

        sentences = down_sampled_db['description'].to_list()
        questions = down_sampled_db['keyword'].to_list()
        yes_or_no = down_sampled_db['yes_or_no']

        toks = tokenizer(sentences, questions, padding='longest')
        ds = Dataset.from_dict({"x": torch.tensor(toks['input_ids']), "mask": torch.tensor(toks['attention_mask']), 'labels' : torch.tensor([ 0 if i == 'no' else 1 for i in yes_or_no])}).with_format("torch")
        dataloader = DataLoader(ds, batch_size=batch_count)

        count_epochs += 1
        print('count_epochs', count_epochs)
            
    print('lowest loss is: ', lowest_loss)
    return nli_model

model = train()

I would expect the error to come during the .backward() call, since the error should arise from the computation of the gradients. Is this not the case? Also, is there a way to remedy this? I'm already using a batch size of one, and can't really use fewer tokens (I'm using a transformer) for this training job.

Do you mean the error arises after the print statement is executed?

Yes, this is correct.

If yes, maybe it happens as AdamW (a stateful optimizer) tries to update the parameters using its internal state, and it also needs memory to store that internal state.

You could try a stateless optimizer like SGD (without momentum).
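For example, something like this (a sketch only; nli_model and learning_rate come from your snippet above):

optimizer = torch.optim.SGD(nli_model.parameters(), lr=learning_rate)  # no momentum, so no per-parameter state

Plain SGD keeps no per-parameter buffers, so optimizer.step() should not allocate extra memory the way AdamW does.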

I actually thought about this as well. I'm just having a hard time understanding why this operation would blow up the GPU's memory.

Yes, the end of the forward pass / start of the backward pass is usually where memory usage peaks, so I'm not sure what is happening here. One way to reduce memory usage is torch.utils.checkpoint (PyTorch 2.0 documentation), which trades compute for memory: instead of saving activations for the backward pass, it recomputes them during backward.
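For example (a sketch only; block stands for one of your encoder submodules, and whether your wrapper exposes the Hugging Face helper below depends on how get_models builds the model):

from torch.utils.checkpoint import checkpoint

# instead of calling block(hidden_states) directly, recompute its
# activations during the backward pass rather than storing them
out = checkpoint(block, hidden_states, use_reentrant=False)

# if the wrapper exposes the underlying Hugging Face encoder, this
# enables the same thing for every layer:
# nli_model.gradient_checkpointing_enable()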

Thanks, I'll check this out, but I still don't understand why the CUDA out of memory error occurs after the gradient is calculated.

Try doing a torch.cuda.synchronize() before the print?
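For example (a sketch against the loop in your code; torch.cuda.synchronize() forces queued kernels to finish, so an asynchronous CUDA error surfaces at that line rather than a later one):

loss = nli_model(x_batch, attention_mask=mask_batch, labels=labels_batch)[0]
torch.cuda.synchronize()  # wait for all queued CUDA work before printing
print(loss)
loss.backward()
torch.cuda.synchronize()  # and again before the next print
print('completed backwards pass')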

Still failed with the same error.

Oh, sorry for the confusion. What I meant to ask is whether it reached print('completed backwards pass') this time. If not, the OOM could actually have happened during the backward pass.

Yeah, it reaches print('completed backwards pass') but doesn't reach print("optimizer zero'd") in the code below:

print('completed backwards pass')

optimizer.step()

optimizer.zero_grad()
print("optimizer zero'd")

So it seems it may be erroring on optimizer.step(). But I would have expected it to error when the gradient is computed.

@srishti-git1110 already pointed out that your AdamW optimizer uses internal states, which are initialized lazily in the first step() call.
This code snippet shows this behavior:

print("{:.3f}MB allocated".format(torch.cuda.memory_allocated()/1024**2))
# 0.0MB allocated

model = models.resnet18().cuda()
print("{:.3f}MB allocated".format(torch.cuda.memory_allocated()/1024**2))
# 44.690MB allocated

x = torch.randn(1, 3, 224, 224, device="cuda")
print("{:.3f}MB allocated".format(torch.cuda.memory_allocated()/1024**2))
# 45.265MB allocated

# no memory increase, as the optimizer's internal states are still empty (created lazily)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
print("{:.3f}MB allocated".format(torch.cuda.memory_allocated()/1024**2))
# 45.265MB allocated
print(optimizer.state_dict())
# {'state': {}, 'param_groups': [{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0.01, 'amsgrad': False, 'foreach': None, 'maximize': False, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]}]}

# stores forward activations
out = model(x)
print("{:.3f}MB allocated".format(torch.cuda.memory_allocated()/1024**2))
# 74.011MB allocated

# calculates gradients, so memory should increase by roughly the parameter size while the forward activations are freed
out.mean().backward()
print("{:.3f}MB allocated".format(torch.cuda.memory_allocated()/1024**2))
# 107.254MB allocated

# updates parameters and creates internal states
optimizer.step()
print("{:.3f}MB allocated".format(torch.cuda.memory_allocated()/1024**2))
# 198.693MB allocated
print(optimizer.state_dict())

Thanks for the response. I switched out the AdamW optimizer for SGD:

optimizer = torch.optim.SGD(nli_model.parameters(), lr=0.01, momentum=0.9)

And I still get the same error. Does the same thing happen with SGD? Also, is this something that could be fixed with model parallelization?

You could use my code to check the memory usage of different optimizers. Model sharding will reduce the per-GPU memory usage and might fix your issue, yes.
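For example, building on the snippet above (a sketch; note that SGD with momentum=0.9 is not state-free either, since it keeps one momentum buffer per parameter, while AdamW keeps two, roughly doubling the parameter memory):

import torch
import torchvision.models as models

def optimizer_state_mb(optimizer):
    # sum the sizes of all tensors the optimizer stores per parameter
    total = 0
    for per_param_state in optimizer.state.values():
        for value in per_param_state.values():
            if torch.is_tensor(value):
                total += value.numel() * value.element_size()
    return total / 1024**2

model = models.resnet18().cuda()
model(torch.randn(1, 3, 224, 224, device="cuda")).mean().backward()

for name, optimizer in [
    ("SGD (no momentum)", torch.optim.SGD(model.parameters(), lr=0.01)),
    ("SGD (momentum=0.9)", torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)),
    ("AdamW", torch.optim.AdamW(model.parameters(), lr=1e-3)),
]:
    optimizer.step()  # internal states are created lazily on the first step
    print(name, "{:.3f}MB of optimizer state".format(optimizer_state_mb(optimizer)))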