Error at loss.backward() when pretraining a Llama model from scratch: "Trying to backward through the graph a second time"

I'm trying to explore the Llama code at https://github.com/facebookresearch/llama.

My training script is shown below:

from datasets import load_dataset
from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel
)
from llama.model_modified import Transformer_modified, ModelArgs
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import pdb

class Model:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001)

    def custom_loss_function(self, outputs, labels):
        softmax_outputs = F.softmax(outputs, dim=-1)

        # Flatten both the softmax_outputs and labels
        softmax_outputs_flat = softmax_outputs.view(-1, softmax_outputs.size(-1))
        labels_flat = labels.view(-1)

        # Compute the cross-entropy loss
        loss = F.cross_entropy(softmax_outputs_flat, labels_flat, reduction='mean')
        return loss

    def mask_tokens(self, input_tensor, mask_prob=0.15):
        """
        Randomly masks tokens in the input tensor for masked language modeling pretraining.

        Args:
            input_tensor (torch.Tensor): Tensor containing input token IDs.
            mask_prob (float): Probability of masking a token.

        Returns:
            (torch.Tensor, torch.Tensor): Masked input tensor, labels tensor.
        """
        mask = torch.rand(input_tensor.shape) < mask_prob
        masked_tensor = input_tensor.clone()
        masked_tensor[mask] = self.tokenizer.mask_token_id

        # Prepare labels tensor for computing loss
        labels = torch.full_like(input_tensor, fill_value=-100)  # -100 is the default ignore_index for cross-entropy loss
        labels[mask] = input_tensor[mask]

        return masked_tensor, labels

    def count_parameters(self):
        count = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        count = count / 1000**2
        return count

    def train_model(self, data_loader):
        # Perform model training here
        for index, (tensor_input, text) in enumerate(data_loader):
            print(f'Index : {index+1}')
            # print(f'Input : {text}')
            self.optimizer.zero_grad()
            tensor_input, targets = self.mask_tokens(tensor_input.squeeze(0))
            # Forward pass
            print(f'Input : {tensor_input} has grad : {tensor_input.requires_grad}')
            output = self.model(tokens=tensor_input, start_pos=0)
            print(f'Output : {output} has grad : {output.requires_grad}')
            # Loss computation
            loss = self.custom_loss_function(output, targets)
            print(f'Loss : {loss.item()}')
            print(f'Loss has grad : {loss.requires_grad}')
            # print(output._version)
            # Backpropagation
            loss.backward()
            # Weight updates
            self.optimizer.step()

class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.sequence_length = max_length

    def __len__(self):
        # Length of the training split, not of the DatasetDict itself
        return len(self.data['train'])

    def __getitem__(self, index):
        # Get data at the specified index
        self.input_ids = self.tokenizer(self.data['train'][index]['text'], return_tensors="pt", max_length=self.sequence_length, truncation=True, padding='max_length')['input_ids']
        return self.input_ids, self.data['train'][index]['text']

def main():
    # Load tokenizer and dataset from the Hugging Face libraries
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    data = load_dataset('rotten_tomatoes')
    sequence_length = 32
    dataset = CustomDataset(data=data, tokenizer=tokenizer, max_length=sequence_length)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Initialize model parallel
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    initialize_model_parallel(1)
    torch.autograd.set_detect_anomaly(True)
    model_args = ModelArgs(dim=32, vocab_size=tokenizer.vocab_size, n_layers=1, n_heads=1, max_seq_len=sequence_length,)
    # pdb.set_trace()
    m = Model(model=Transformer_modified(model_args), tokenizer=tokenizer)
    print(f'Model has {m.count_parameters():.1f}M trainable parameters.')
    # model = DDP(model, device_ids=[0], output_device=0)
    m.train_model(data_loader)
    # clean up
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

I get the following RuntimeError: "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward."

I can't find the issue. I have removed the @torch.inference_mode() decorator from the Transformer class inside the Llama model.py to allow training.
Can someone explain to me what's going wrong here?

OUTPUT:

Model has 2.0M trainable parameters.
Index : 1
Input : tensor([[ 101, 1996, 9882, 2135, 9603, 13633, 1997, 1000, 1996, 2935,
1997, 1996, 7635, 1000, 11544, 2003, 2061, 4121, 2008, 1037,
103, 1997, 2616, 3685, 23613, 6235, 2522, 1011, 3213, 1013,
2472, 103]]) has grad : False
Output : tensor([[[-0.0419, -0.7551, -0.4873, …, -0.2259, 0.3079, 0.3085],
[ 0.1105, -0.3129, -1.0512, …, 0.7456, -0.4197, 1.0040],
[-0.1352, 0.1497, -0.6379, …, -0.0768, 1.3803, 0.8941],
…,
[-0.1277, -0.4814, 0.0347, …, 0.3797, -0.1792, 1.4547],
[ 0.2541, 0.9202, -0.2216, …, -1.2154, -1.1796, -1.2808],
[ 0.6918, -0.6579, -0.8302, …, 0.5668, -0.5223, 1.2523]]],
grad_fn=<...>) has grad : True
Loss : 10.326194763183594
Loss has grad : True
Index : 2
Input : tensor([[ 101, 4621, 2021, 2205, 1011, 8915, 23267, 16012, 24330, 102,
0, 0, 0, 0, 0, 103, 0, 0, 0, 0,
0, 103, 0, 0, 0, 0, 103, 0, 0, 0,
0, 0]]) has grad : False
Output : tensor([[[-0.0462, -0.7632, -0.4925, …, -0.2175, 0.3111, 0.3207],
[-0.5318, 0.5070, 0.2948, …, -0.5086, -0.8850, -1.0930],
[-0.9778, -0.4862, 0.1357, …, -0.2190, 0.3070, -0.3706],
…,
[ 1.3116, 0.1696, -0.0291, …, 0.4612, -0.2122, 0.9665],
[ 1.3152, 0.1716, -0.0287, …, 0.4622, -0.2161, 0.9684],
[ 1.3191, 0.1746, -0.0282, …, 0.4631, -0.2187, 0.9687]]],
grad_fn=<...>) has grad : True
Loss : 10.326163291931152
Loss has grad : True
/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in torch::autograd::CopySlices. Traceback of forward call that caused the error:
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 118, in <module>
    main()
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 113, in main
    m.train_model(data_loader)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 69, in train_model
    output = self.model(tokens = tensor_input, start_pos = 0)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/llama/model_modified.py", line 456, in forward
    h = layer(h, start_pos, freqs_cis, mask)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/llama/model_modified.py", line 372, in forward
    h = x + self.attention.forward(
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/llama/model_modified.py", line 258, in forward
    self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
(Triggered internally at …/torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 118, in <module>
    main()
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 113, in main
    m.train_model(data_loader)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 77, in train_model
    loss.backward()
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Hi abbash! I'm trying to modify the Llama source code from GitHub for my own project, and this error also happened to me: the first loss.backward() in the loop works fine, but the error occurs the second time. Did you solve it? I hope we can help each other. From the discussions of this error that I could find, it seems that some constant tensors may be involved in the computation of the loss, and their saved intermediate values are freed from the computation graph the first time loss.backward() is called. But I haven't been able to find any tensor like that in my code for days, so I guess it's something in the Llama code. I also tried calling detach() on some variables, but it didn't work for me.
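To illustrate the pattern I mean, here is a minimal sketch (a made-up CachingLayer module, nothing to do with the actual Llama code): a buffer that is updated in the forward pass and carried across iterations keeps the previous iteration's graph alive, so the second backward hits exactly this error.

import torch
import torch.nn as nn

class CachingLayer(nn.Module):
    # Made-up module: keeps a running cache tensor across forward calls,
    # similar in spirit to a cache that is written on every step.
    def __init__(self, dim=8):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.cache = torch.zeros(1, dim)  # persists between iterations

    def forward(self, x):
        out = self.linear(x)
        # The new cache depends on the old cache, so this iteration's loss
        # still references the previous iteration's (already freed) graph.
        self.cache = self.cache + out
        return self.cache

model = CachingLayer()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(2):
    opt.zero_grad()
    loss = model(torch.randn(1, 8)).pow(2).mean()
    loss.backward()  # fine on step 0; on step 1 it raises
                     # "Trying to backward through the graph a second time ..."
    opt.step()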

Hello Wanghy,

The only way I found is to use the retain_graph=True flag with loss.backward(), which lets you accumulate the losses within the graph. The way ahead is then to call step() once the model has seen enough batches of data. I have modified my training loop so that the model is stepped, then saved and loaded again to resume training. It doesn't look smooth, but it worked for me.

So, setting the flag and modifying the training process solved the problem.
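Roughly, my modified loop looks something like the sketch below (simplified, not my exact code; accumulation_steps and checkpoint_path are just placeholder names):

    def train_model(self, data_loader, accumulation_steps=8, checkpoint_path='checkpoint.pt'):
        self.optimizer.zero_grad()
        for index, (tensor_input, text) in enumerate(data_loader):
            tensor_input, targets = self.mask_tokens(tensor_input.squeeze(0))
            output = self.model(tokens=tensor_input, start_pos=0)
            loss = self.custom_loss_function(output, targets)
            # retain_graph=True keeps the saved tensors alive, so backward can
            # run again even though cached activations span iterations
            loss.backward(retain_graph=True)
            if (index + 1) % accumulation_steps == 0:
                # step only after the model has seen enough batches
                self.optimizer.step()
                self.optimizer.zero_grad()
                # save and reload the model to resume training from a fresh state
                torch.save(self.model.state_dict(), checkpoint_path)
                self.model.load_state_dict(torch.load(checkpoint_path))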

Hope this proves helpful!

Best regards