Error at loss.backward() when pretraining a Llama model from scratch: "Trying to backward through the graph a second time"

I'm trying to explore the Llama code at https://github.com/facebookresearch/llama.

My training script is shown below:

from datasets import load_dataset
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
)
from llama.model_modified import Transformer_modified, ModelArgs
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import pdb

class Model:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001)

    def custom_loss_function(self, outputs, labels):
        softmax_outputs = F.softmax(outputs, dim=-1)

        # Flatten both the softmax_outputs and labels
        softmax_outputs_flat = softmax_outputs.view(-1, softmax_outputs.size(-1))
        labels_flat = labels.view(-1)

        # Compute the cross-entropy loss
        loss = F.cross_entropy(softmax_outputs_flat, labels_flat, reduction='mean')
        return loss

    def mask_tokens(self, input_tensor, mask_prob=0.15):
        """
        Randomly masks tokens in the input tensor for masked language modeling pretraining.

        Args:
            input_tensor (torch.Tensor): Tensor containing input token IDs.
            mask_prob (float): Probability of masking a token.

        Returns:
            (torch.Tensor, torch.Tensor): Masked input tensor, labels tensor.
        """
        mask = torch.rand(input_tensor.shape) < mask_prob
        masked_tensor = input_tensor.clone()
        masked_tensor[mask] = self.tokenizer.mask_token_id

        # Prepare labels tensor for computing loss
        labels = torch.full_like(input_tensor, fill_value=-100)  # -100 is the default ignore_index for cross-entropy loss
        labels[mask] = input_tensor[mask]

        return masked_tensor, labels

    def count_parameters(self):
        count = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        count = count / 1000**2
        return count

    def train_model(self, data_loader):
        for index, (tensor_input, text) in enumerate(data_loader):
            print(f'Index : {index+1}')
            self.optimizer.zero_grad()
            tensor_input, targets = self.mask_tokens(tensor_input.squeeze(0))
            # Forward pass
            print(f'Input : {tensor_input} has grad : {tensor_input.requires_grad}')
            output = self.model(tokens=tensor_input, start_pos=0)
            print(f'Output : {output} has grad : {output.requires_grad}')
            # Loss computation
            loss = self.custom_loss_function(output, targets)
            print(f'Loss : {loss.item()}')
            print(f'Loss has grad : {loss.requires_grad}')
            # Backpropagation
            loss.backward()
            # Weight updates
            self.optimizer.step()

class CustomDataset(Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.sequence_length = max_length

    def __len__(self):
        return len(self.data['train'])  # number of examples in the train split

    def __getitem__(self, index):
        # Tokenize the example at the specified index
        input_ids = self.tokenizer(self.data['train'][index]['text'], return_tensors="pt", max_length=self.sequence_length, truncation=True, padding='max_length')['input_ids']
        return input_ids, self.data['train'][index]['text']

def main():
    # Load datasets from the datasets library
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    data = load_dataset('rotten_tomatoes')
    sequence_length = 32
    dataset = CustomDataset(data=data, tokenizer=tokenizer, max_length=sequence_length)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

    # Initialize model parallel
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    initialize_model_parallel(1)
    torch.autograd.set_detect_anomaly(True)
    model_args = ModelArgs(dim=32, vocab_size=tokenizer.vocab_size, n_layers=1, n_heads=1, max_seq_len=sequence_length)
    # pdb.set_trace()
    m = Model(model=Transformer_modified(model_args), tokenizer=tokenizer)
    print(f'Model has {m.count_parameters():.1f}M trainable parameters.')
    # model = DDP(model, device_ids=[0], output_device=0)
    m.train_model(data_loader)
    # clean up
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
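(Note: the script has to be launched through torchrun so that init_process_group can pick up its rendezvous environment variables, e.g. torchrun --nproc_per_node=1 train_clean.py for a single process.)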

On the second batch, I get the RuntimeError "Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed)" (full output and traceback below).

I can't find the issue. I have already removed the @torch.inference_mode() decorator from the Transformer class inside Llama's model.py so that gradients are tracked during training.
Can someone explain to me what's going wrong here?
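If it helps, here is a minimal sketch (my own simplification, not code from the repository) of the pattern I suspect: the attention layers keep persistent cache_k / cache_v buffers and write into them in-place on every forward pass, so the buffer still carries batch 1's graph when batch 2 runs:

import torch

w = torch.randn(4, requires_grad=True)
cache = torch.zeros(4)  # persistent buffer, standing in for cache_v

for step in range(2):
    cache[:] = w * (step + 1)  # in-place write, recorded as CopySlices
    loss = cache.sum()
    loss.backward()  # step 2 tries to traverse step 1's already-freed graph

As far as I can tell, this raises the same "Trying to backward through the graph a second time" error on the second iteration.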

OUTPUT:

Model has 2.0M trainable parameters.
Index : 1
Input : tensor([[ 101, 1996, 9882, 2135, 9603, 13633, 1997, 1000, 1996, 2935,
1997, 1996, 7635, 1000, 11544, 2003, 2061, 4121, 2008, 1037,
103, 1997, 2616, 3685, 23613, 6235, 2522, 1011, 3213, 1013,
2472, 103]]) has grad : False
Output : tensor([[[-0.0419, -0.7551, -0.4873, ..., -0.2259, 0.3079, 0.3085],
[ 0.1105, -0.3129, -1.0512, ..., 0.7456, -0.4197, 1.0040],
[-0.1352, 0.1497, -0.6379, ..., -0.0768, 1.3803, 0.8941],
...,
[-0.1277, -0.4814, 0.0347, ..., 0.3797, -0.1792, 1.4547],
[ 0.2541, 0.9202, -0.2216, ..., -1.2154, -1.1796, -1.2808],
[ 0.6918, -0.6579, -0.8302, ..., 0.5668, -0.5223, 1.2523]]],
grad_fn=) has grad : True
Loss : 10.326194763183594
Loss has grad : True
Index : 2
Input : tensor([[ 101, 4621, 2021, 2205, 1011, 8915, 23267, 16012, 24330, 102,
0, 0, 0, 0, 0, 103, 0, 0, 0, 0,
0, 103, 0, 0, 0, 0, 103, 0, 0, 0,
0, 0]]) has grad : False
Output : tensor([[[-0.0462, -0.7632, -0.4925, ..., -0.2175, 0.3111, 0.3207],
[-0.5318, 0.5070, 0.2948, ..., -0.5086, -0.8850, -1.0930],
[-0.9778, -0.4862, 0.1357, ..., -0.2190, 0.3070, -0.3706],
...,
[ 1.3116, 0.1696, -0.0291, ..., 0.4612, -0.2122, 0.9665],
[ 1.3152, 0.1716, -0.0287, ..., 0.4622, -0.2161, 0.9684],
[ 1.3191, 0.1746, -0.0282, ..., 0.4631, -0.2187, 0.9687]]],
grad_fn=) has grad : True
Loss : 10.326163291931152
Loss has grad : True
/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in torch::autograd::CopySlices. Traceback of forward call that caused the error:
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 118, in <module>
    main()
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 113, in main
    m.train_model(data_loader)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 69, in train_model
    output = self.model(tokens = tensor_input, start_pos = 0)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/llama/model_modified.py", line 456, in forward
    h = layer(h, start_pos, freqs_cis, mask)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/llama/model_modified.py", line 372, in forward
    h = x + self.attention.forward(
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/llama/model_modified.py", line 258, in forward
    self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
(Triggered internally at .../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 118, in <module>
    main()
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 113, in main
    m.train_model(data_loader)
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/train_clean.py", line 77, in train_model
    loss.backward()
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/home/3057693@eeecs.qub.ac.uk/Documents/llama/env/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
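
From the anomaly-mode traceback, the in-place write to self.cache_v (and presumably the matching one for cache_k) looks like what chains one batch's graph to the next. Would detaching the caches between batches, along the lines of this sketch (cache_k / cache_v are the buffer names used in llama/model.py), be the right fix, or should the KV cache be bypassed entirely during training?

def detach_kv_caches(model):
    # Drop autograd history from the persistent KV caches so the next
    # forward pass does not reference the previous batch's graph.
    for module in model.modules():
        if hasattr(module, "cache_k"):
            module.cache_k = module.cache_k.detach()
        if hasattr(module, "cache_v"):
            module.cache_v = module.cache_v.detach()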