Fine-tuning GPT-2 on multiple GPUs and still not enough memory

I checked your example code for using torch.nn.parallel.DistributedDataParallel to train a model on multiple GPUs on the same host. I’ve modified the code to fine-tune (unsupervised learning) the smallest GPT-2 model, and I have 4 x 8 GB graphics cards.

I thought 32 GB of memory should be enough for the smallest GPT-2 model (even the medium one should kind of work?!), but I still get errors like:

...
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [129,0,0], thread: [57,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
...
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_INITIALIZED when calling `cublasCreate(handle)`

The code I use:

import argparse
import os
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config


# Define the training parameters
max_length = 128
learning_rate = 1e-5


class MyTrainDataset(Dataset):
    def __init__(self, tokenizer, max_length):
        self.tokenizer = tokenizer
        self.max_length = max_length
        #self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)]

        # Read and tokenize the text file
        with open("your_dataset.txt", "r", encoding="utf-8") as file:
            self.text = file.read().replace("\n", " ")

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        inputs = self.tokenizer.encode_plus(
            self.text[idx : idx + self.max_length],
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True
        )
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]

        return torch.tensor(input_ids), torch.tensor(attention_mask)


def ddp_setup(rank, world_size):
    """
    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
    """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
    ) -> None:
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.model = DDP(model, device_ids=[gpu_id])

    def _run_batch(self, source, targets):
        self.optimizer.zero_grad()
        output = self.model(source)
        loss = F.cross_entropy(output, targets)
        loss.backward()
        self.optimizer.step()

    def _run_epoch(self, epoch):
        b_sz = len(next(iter(self.train_data))[0])
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)}")
        self.train_data.sampler.set_epoch(epoch)
        for source, targets in self.train_data:
            source = source.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(source, targets)

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)


def load_train_objs(tokenizer):
    train_set = MyTrainDataset(tokenizer, max_length)  # load your dataset

    config = GPT2Config.from_pretrained("gpt2")
    model = GPT2LMHeadModel(config)
    #model = torch.nn.Linear(20, 1)  # load your model

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    return train_set, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset)
    )


def main(rank: int, world_size: int, total_epochs: int, batch_size: int):
    ddp_setup(rank, world_size)

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    dataset, model, optimizer = load_train_objs(tokenizer)
    train_data = prepare_dataloader(dataset, batch_size)
    trainer = Trainer(model, train_data, optimizer, rank)
    trainer.train(total_epochs)
    destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='simple distributed training job')
    parser.add_argument('total_epochs', type=int, help='Total epochs to train the model')
    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()

    world_size = torch.cuda.device_count()
    mp.spawn(main, args=(world_size, args.total_epochs, args.batch_size), nprocs=world_size)

I run it with:

python3 training.py --batch_size 1 1

…which is the smallest possible batch size and a single epoch. The dataset file your_dataset.txt contains only a little over 10 lines and is 5.1 kB in total.

Is my multiprocessing not working, or are 4 x 8 GB GPUs simply not enough for this task?

I would recommend fixing the first indexing error:

../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [129,0,0], thread: [57,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

before checking the memory usage.

Any ideas how exactly? All the code I use is posted here.

You would need to check the stacktrace to narrow down which layer and operation fail.
Use blocking launches via CUDA_LAUNCH_BLOCKING=1 to make sure the stacktrace points to the failing layer.
My guess is that an embedding layer is failing due to an invalid input index.
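
For example, a quick CPU-side check along these lines (a rough sketch, assuming the model and train_data objects built by load_train_objs and prepare_dataloader in your script) would show whether any token ID exceeds the embedding size:

# Compare the largest token ID in each batch with the number of rows
# in the word-token embedding (50257 for the "gpt2" config).
vocab_size = model.transformer.wte.num_embeddings

for input_ids, attention_mask in train_data:
    max_id = int(input_ids.max())
    if max_id >= vocab_size:
        print(f"out-of-range token ID {max_id}, embedding has only {vocab_size} rows")
        break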

Your code is posted, but it’s not executable, so it doesn’t help with debugging.

Unfortunately, CUDA_LAUNCH_BLOCKING=1 doesn’t help me understand where the issue is.

Here is the content of the your_dataset.txt file used by the code:

= Chad at the 2008 Summer Olympics = 
 Chad sent a delegation of two athletes to compete at the 2008 Summer Olympics in Beijing , China : Moumi Sébergué , who competed in the men 's 100 meters , and Hinikissia Albertine Ndikert , who competed in the women 's 100 meters and also bore the Chadian flag during ceremonies . The appearance of this delegation marked the tenth appearance of Chad at the Summer Olympics , the first been in 1964 Summer Olympics in Tokyo , Japan , and its seventh appearance since its Olympic hiatus between 1976 and 1980 . Both Sébergué and Ndikert ranked seventh in their respective heats and did not advance past the qualification round . As of the end of the 2012 London Olympics , there have been no medalists from Chad . 
= = Background = = 
Chad is a landlocked country in Africa whose northern region lies within the eastern reaches of the Sahara Desert and whose southern region lies within the eastern portion of the Sahel . It borders Libya to the south , Niger to the east , Sudan to the west , and the Central African Republic to the north . Chad was originally part of French West Africa until 1960 , when it declared independence . Some four years later , the former French colony made its début at the 1964 Summer Olympics in Tokyo , Japan . For the next three decades , Chad became embroiled in civil war and experienced invasions by Libya and upheavals by Sudanese @-@ backed rebels ; the civil war ended in 1990 , although rebel threats had persisted between then and 2008 . During Chad 's greatest era of instability , athletes from the country did not attend the 1976 Summer Olympics in Montréal , Canada or the 1980 Summer Olympics in Moscow , USSR , although delegations were sent to all other games between 1964 and 2008 . 
The largest Chadian delegation to reach the Olympics appeared in the 1988 Summer Olympics in Seoul , South Korea and at the 1992 Summer Olympics in Barcelona , Spain ; each time , Chad 's National Olympic Committee sent six athletes . During the 1992 games , the NOC sent the nation 's first female Olympian . Since then ( and up to the Beijing games ) , at least one woman has been a part of the Chadian delegation . The smallest contingency of Chadian Olympians occurred during the 2004 Summer Olympics in Athens , Greece , when only Kaltouma Nadjina competed on the country 's behalf . The delegation that arrived in Beijing consisted of two athletes — one man ( 30 @-@ year @-@ old Moumi Sébergué ) and one woman ( 15 @-@ year @-@ old Hinikissia Albertine Ndikert ) , both participants in track events . Ndikert was Chad 's flagbearer at the ceremonies . Up to and including the Beijing games , there has yet to have been a medalist from Chad . 
= = Athletics = = 
Competitors in athletics events could qualify for the next round of competition in two ways . Qualifying by right was posting a high result in their own heat , and qualifying by result was posting a high result in overall standings . Ranks shown are thus those within each heat , not in overall standings . 
Moumi Sébergué represented Chad at the Beijing Olympics in the men 's 100 meters dash . Born in 1977 , Sébergué first participated in the Olympics at age 22 when he raced in the men 's 100 meters at the 2000 Summer Olympics in Sydney , Australia , placing seventh in his qualification heat and not progressing to later rounds . He did not attend the 2004 Summer Olympics in Athens , Greece , but returned to the Olympics at Beijing at the age of 30 . During the course of the August 14 , 2008 races in his event , when the qualification round took place , Sébergué competed in the tenth heat against seven other athletes . He finished the race in 11 @.@ 14 seconds , placing seventh in the heat ahead of Tuvalu 's Okinali Tinilau ( 11 @.@ 48 seconds ) and behind Gabon 's Wilfried Bingangoye ( 10 @.@ 87 seconds ) in a heat led by the Netherlands Antilles ' Churandy Martina ( 10 @.@ 35 seconds ) and Japan 's Naoki Tsukahara ( 10 @.@ 39 seconds ) . Of the 80 athletes who participated in the events , the Chadian sprinter ranked 70th . He did not advance to later rounds . 
Hinikissia Albertine Ndikert competed on Chad 's behalf as the national delegation 's only female athlete at the Beijing games . She participated in the women 's 100 meters dash , and was 15 years old at the time of the competition . Ndikert had not previously competed in any Olympic games . During the qualification round of the event , which took place on August 15 , 2008 , Ndikert competed in the eighth heat against seven other athletes . She finished the race in 12 @.@ 55 seconds , placing seventh ; she defeated the Democratic Republic of the Congo 's Franka Magali ( 12 @.@ 57 seconds ) and fell behind Papua New Guinea 's Mae Koime ( 11 @.@ 68 seconds ) in a heat led by Nigeria 's Damola Osayomi ( 11 @.@ 13 seconds ) and the Bahamas ' Debbie Ferguson @-@ McKenzie ( 11 @.@ 17 seconds ) . Of the event 's 85 competitors , Ndikert finished in 64th place . Therefore , Ndikert did not advance to round two and beyond .

Now you should also be able to copy-paste the code and run it.

As already guessed, the embedding layer is failing, since it uses num_embeddings=50257:

DistributedDataParallel(
  (module): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)

and thus expects inputs with token indices in [0, 50256], while you are passing input index 50257, which is out of bounds:

tensor([[ 7632,  4201,  1267,   290,  3214,  2157, 46117,   968, 22777,   705,
            82, 34673, 17634,   524,   357,  1367,  2488,    13,    31,  8257,
          4201,  1267,   287,   257,  4894,  2957,   416, 19398,   705,    82,
          5245,  5708,  8834,   323, 12753,   357,   352, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257,
         50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257]],
       device='cuda:0')

Using:

source[source==50257] = 50256

fixes the indexing error, and then the run fails with the next error:

TypeError: cross_entropy_loss(): argument 'input' (position 1) must be Tensor, not CausalLMOutputWithCrossAttentions

It seems that my custom padding token produced the extra ID 50257. So I replaced:

tokenizer.pad_token = tokenizer.eos_token
#tokenizer.add_special_tokens({'pad_token': '[PAD]'})
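
An alternative that keeps the dedicated [PAD] token would be to grow the embedding matrix so that ID 50257 becomes a valid index; a minimal sketch using the resize_token_embeddings helper from transformers:

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # adds ID 50257

config = GPT2Config.from_pretrained("gpt2")
model = GPT2LMHeadModel(config)
# grow the token embedding (and tied lm_head) to len(tokenizer) == 50258 rows
model.resize_token_embeddings(len(tokenizer))

Reusing the EOS token as padding avoids touching the model at all, though.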

To satisfy cross-entropy’s requirement of receiving a Tensor, I changed _run_batch like this:

def _run_batch(self, source, targets):
    self.optimizer.zero_grad()
    output = self.model(source)
    output_logits = output.logits
    loss = F.cross_entropy(output_logits, targets)
    loss.backward()
    self.optimizer.step()

The latest error I get is:

...
    loss = F.cross_entropy(output_logits_tensor, targets)
  File "/root/alexey/gpt-2/env/lib/python3.10/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: Expected target size [1, 50257], got [1, 128]

“128” must be the maximum input sequence length; changing it to 50257 is simply wrong, I guess.

What do I do now?

nn.CrossEntropyLoss expects model outputs containing logits in the shape [batch_size, nb_classes, *] and targets containing class indices in the range [0, nb_classes-1] in the shape [batch_size, *], where * denotes additional dimensions.
Check the model output and target shape and make sure they are using the expected layout.
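
For this causal-LM setup, one way to get the layout right is to let GPT2LMHeadModel compute the loss itself by passing labels (it shifts them internally for next-token prediction). Note that in your __getitem__ the second returned tensor is the attention mask, so the input IDs themselves have to act as the labels. A rough sketch of _run_batch along those lines (not a drop-in tested fix):

def _run_batch(self, source, targets):
    # `targets` is the attention mask returned by __getitem__, not labels.
    self.optimizer.zero_grad()
    # Passing `labels` makes GPT2LMHeadModel compute the causal-LM loss
    # itself. (For cleaner training, padded positions in the labels would
    # usually be set to -100 so the loss ignores them; omitted here.)
    output = self.model(input_ids=source, attention_mask=targets, labels=source)
    loss = output.loss
    # Manual alternative: F.cross_entropy wants logits as [N, nb_classes],
    # so shift the labels by one position and flatten both tensors:
    # logits = output.logits[:, :-1, :]
    # loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
    #                        source[:, 1:].reshape(-1))
    loss.backward()
    self.optimizer.step()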