Training fails when using 2 GPUs

Dear all,

I’m using a model named Feature-Tokenizer Transformer, written in PyTorch. I’m running several experiments on a tabular dataset, and until now everything worked fine. But while doing an ablation study (removing some features to check the model’s performance), my jobs suddenly started crashing when I removed the first features (I have 10 633 features, and the first 10 481 can be grouped together, which leaves only 152 features). If I remove other features instead (the last 152, for example, while keeping the first 10 481), it works fine, but not if I remove the first group. After some tests, I realized that it fails only on multiple GPUs (I’m using torch.nn.DataParallel and, for reasons related to the framework I use, I can’t move to torch.nn.parallel.DistributedDataParallel).
Here is the error I receive:

Traceback (most recent call last):
  File "/home/my_user_name/my-framework/", line 38, in <module>
    run(args)  # Run the framework with the parsed arguments
  File "/home/my_user_name/my-framework/", line 32, in run  # Run the pipeline with the provided arguments
  File "/home/my_user_name/my-framework/pipelines/", line 38, in run, X, y, le_species, le_header)  # Run the model's evaluation
  File "/home/my_user_name/my-framework/models/", line 444, in run
    self.evaluate_model(args, X, y, le_species, le_header)  # Run evaluation pipeline
  File "/home/my_user_name/my-framework/models/", line 100, in evaluate_model
    best_fold_accuracy = self.evaluate_fold(args, X, y, le_species, le_header, fold, split_assignments)
  File "/home/my_user_name/my-framework/models/", line 316, in evaluate_fold
    accuracy, model = self.evaluate_epoch(args, model, train_dataloader, test_dataloader, epoch, device, optimizer, criterion, le_header)
  File "/home/my_user_name/my-framework/models/", line 408, in evaluate_epoch
    outputs = model(x_num=batch, x_cat=None)  # Forward pass to get output/logits
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/modules/", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/parallel/", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/parallel/", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/parallel/", line 89, in parallel_apply
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/", line 543, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/parallel/", line 64, in _worker
    output = module(*input, **kwargs)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/modules/", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/rtdl/", line 1487, in forward
    x = self.transformer(x)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/modules/", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/rtdl/", line 1150, in forward
    x_residual, _ = layer['attention'](
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/modules/", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/rtdl/", line 893, in forward
    k = key_compression(k.transpose(1, 2)).transpose(1, 2)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/modules/", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/my_user_name/.conda/envs/my-env/lib/python3.11/site-packages/torch/nn/modules/", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: size mismatch, got 4096, 4096x153,0

And here is my code:

model_parameters = get_model_parameters(args, X_train.shape[1], len(le_header.classes_))   # Get model parameters
model = FTT(**model_parameters)  # Create an instance of the FTT model
device = get_device(args)  # Set the device to run the model on to be GPU
model = torch.nn.DataParallel(model)  # Run the model in parallel on the available GPUs
model.to(device)  # Move the model to the specified device (GPU or CPU)
criterion = get_criterion(args, y_train)  # Instantiate loss class
criterion.to(device)  # Move the loss function to the specified device (GPU or CPU)
optimizer = get_optimizer(args, model)  # Instantiate optimizer class
scheduler = get_scheduler(args, optimizer)  # Instantiate step learning scheduler class
for epoch in range(args.num_epochs):
    for batch, labels in train_dataloader:  # Iterate through train dataset
        batch = batch.requires_grad_().to(device)  # Load batches with gradient accumulation capabilities
        labels = labels.to(device)  # Use GPU for tensors
        optimizer.zero_grad()  # Clear gradients w.r.t. parameters
        outputs = model(x_num=batch, x_cat=None)  # Forward pass to get output/logits
        loss = criterion(outputs, labels)  # Calculate Loss: softmax --> cross entropy loss
        loss.backward()  # Getting gradients w.r.t. parameters
        optimizer.step()  # Updating parameters
# It continues but the crash happens before

The code fails the first time it enters the loop and computes the forward pass. Once again, I want to emphasize these facts:

  • it only happens when I remove the first group of features, i.e., the first 10 481 features. If I remove other features (or if I don’t remove any feature), the code works fine.
  • it only happens if I’m on a cluster with several GPUs. If I run it on my local machine (on which I have only one GPU), the code works fine.
  • the shape of one batch is [batch_size, num_input_features], i.e., [512, 152]. I don’t know why the error is mentioning the number 4096 (153 is normal, it might seem strange but the attribute n_tokens of the model is actually the number of input features + 1).
  • this is what I get if I print the devices of the parameters of my model:
for i in model.named_parameters():
    print(f"{i[0]} -> {i[1].device}")
module.feature_tokenizer.num_tokenizer.weight -> cuda:0
module.feature_tokenizer.num_tokenizer.bias -> cuda:0
module.cls_token.weight -> cuda:0
module.transformer.blocks.0.attention.W_q.weight -> cuda:0
module.transformer.blocks.0.attention.W_q.bias -> cuda:0
module.transformer.blocks.0.attention.W_k.weight -> cuda:0
module.transformer.blocks.0.attention.W_k.bias -> cuda:0
module.transformer.blocks.0.attention.W_v.weight -> cuda:0
module.transformer.blocks.0.attention.W_v.bias -> cuda:0
module.transformer.blocks.0.attention.W_out.weight -> cuda:0
module.transformer.blocks.0.attention.W_out.bias -> cuda:0
module.transformer.blocks.0.ffn.linear_first.weight -> cuda:0
module.transformer.blocks.0.ffn.linear_first.bias -> cuda:0
module.transformer.blocks.0.ffn.linear_second.weight -> cuda:0
module.transformer.blocks.0.ffn.linear_second.bias -> cuda:0
module.transformer.blocks.0.ffn_normalization.weight -> cuda:0
module.transformer.blocks.0.ffn_normalization.bias -> cuda:0
module.transformer.blocks.0.key_compression.weight -> cuda:0
module.transformer.blocks.0.value_compression.weight -> cuda:0
module.transformer.head.normalization.weight -> cuda:0
module.transformer.head.normalization.bias -> cuda:0
module.transformer.head.linear.weight -> cuda:0
module.transformer.head.linear.bias -> cuda:0
  • if I print the shape and device of the batch/labels, I get this:
torch.Size([512, 152])
  • if I print the model, I get this:
  (module): FTT(
    (feature_tokenizer): FeatureTokenizer(
      (num_tokenizer): NumericalFeatureTokenizer()
    (cls_token): CLSToken()
    (transformer): Transformer(
      (blocks): ModuleList(
        (0): ModuleDict(
          (attention): MultiheadAttention(
            (W_q): Linear(in_features=16, out_features=16, bias=True)
            (W_k): Linear(in_features=16, out_features=16, bias=True)
            (W_v): Linear(in_features=16, out_features=16, bias=True)
            (W_out): Linear(in_features=16, out_features=16, bias=True)
            (dropout): Dropout(p=0.3, inplace=False)
          (ffn): FFN(
            (linear_first): Linear(in_features=16, out_features=32, bias=True)
            (activation): ReGLU()
            (dropout): Dropout(p=0.1, inplace=False)
            (linear_second): Linear(in_features=16, out_features=16, bias=True)
          (attention_residual_dropout): Dropout(p=0.0, inplace=False)
          (ffn_residual_dropout): Dropout(p=0.0, inplace=False)
          (output): Identity()
          (ffn_normalization): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
          (key_compression): Linear(in_features=153, out_features=0, bias=False)
          (value_compression): Linear(in_features=153, out_features=0, bias=False)
      (head): Head(
        (normalization): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (activation): ReLU()
        (linear): Linear(in_features=16, out_features=228, bias=True)

Thanks for your help!

Add debug print statements to the forward method of your model to check the shapes of all tensors. That should narrow down where the unexpected shape is created, and also show which input shape each GPU gets after your slicing operation.
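As an alternative to editing forward by hand, forward hooks can log the output shape of every submodule. The snippet below is only a minimal sketch: attach_shape_hooks and the toy nn.Sequential model are hypothetical names used for illustration, not part of rtdl or of the code in this thread. It has not been checked against DataParallel here; each entry records the device so that output from different replicas could be told apart.

```python
import torch
from torch import nn


def attach_shape_hooks(model: nn.Module, log: list) -> list:
    """Register a forward hook on every leaf submodule that records its output shape."""
    handles = []
    for name, module in model.named_modules():
        if len(list(module.children())) > 0:
            continue  # skip containers; only leaf modules are logged

        def hook(mod, inputs, output, name=name):
            # Some modules return tuples; only plain tensors are logged here.
            if isinstance(output, torch.Tensor):
                log.append((name, str(output.device), tuple(output.shape)))

        handles.append(module.register_forward_hook(hook))
    return handles


# Toy stand-in model (hypothetical; NOT the FTT model from this thread).
toy = nn.Sequential(nn.Linear(152, 16), nn.ReLU(), nn.Linear(16, 228))
shape_log = []
handles = attach_shape_hooks(toy, shape_log)

toy(torch.randn(4, 152))  # one forward pass fills shape_log
for entry in shape_log:
    print(entry)

for handle in handles:
    handle.remove()  # detach the hooks once debugging is done
```

Removing the handles afterwards keeps the hooks from slowing down or cluttering later training runs.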

Dear @ptrblck,

Thanks for your answer.
To be honest, I was using the forward method inherited from the repository (without defining it explicitly in my code).
So I copied it from the source code and added print statements as requested:

def forward(self, x_num=None, x_cat=None):
    x = self.feature_tokenizer(x_num, x_cat)
    print("1  --> ", x.shape)
    x = self.cls_token(x)
    print("2  --> ", x.shape)
    x = self.transformer(x)
    print("3  --> ", x.shape)
    return x

Here is the output:

1  -->  torch.Size([256, 152, 16])
1  -->  torch.Size([256, 152, 16])
2  -->  torch.Size([256, 153, 16])
2  -->  torch.Size([256, 153, 16])
3  -->  torch.Size([256, 228])
# After that it crashes and outputs the error message I copied in the original message

For your information, 256 is the batch size each GPU receives (my batch of 512 split across the 2 GPUs), 152 is the number of input features, 153 is 152+1 which is the number of tokens, and 16 is both the size of one token and the input size for the second linear layer in the FFN module.
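The numbers in this output can be reproduced with a little arithmetic. The sketch below assumes that DataParallel splits the batch along dimension 0 the way torch.chunk does (chunks of ceil(batch_size / n_devices)); per_device_batch_sizes is a hypothetical helper written for illustration, not a PyTorch function.

```python
import math


def per_device_batch_sizes(batch_size: int, n_devices: int) -> list:
    """Chunk sizes along dim 0, assuming a torch.chunk-like split."""
    chunk = math.ceil(batch_size / n_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= sizes[-1]
    return sizes


batch_size, n_features, d_token = 512, 152, 16
n_tokens = n_features + 1  # one extra token is appended by cls_token

for size in per_device_batch_sizes(batch_size, n_devices=2):
    print("after feature_tokenizer:", (size, n_features, d_token))
    print("after cls_token:        ", (size, n_tokens, d_token))
```

Under that assumption, a batch of 512 on 2 GPUs gives two chunks of 256, which matches the shapes printed by each replica.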

The stack trace points to a self.transformer module called from the rtdl package, so the forward pass should show the shape change. Could you explain what rtdl is and which module calls it? You might need to dig into it and check the shapes of the intermediates there too.

Dear @ptrblck,

Thanks for your answer.
rtdl is a PyTorch-based package containing several neural networks that can be used for tabular datasets. For example, the model that I use is a Feature-Tokenizer Transformer, instantiated as follows:

class FTT(rtdl.FTTransformer):
    def __init__(self, n_num_features=None, cat_cardinalities=None, d_token=16, n_blocks=1, attention_n_heads=4, attention_dropout=0.3, attention_initialization='kaiming', attention_normalization='LayerNorm', ffn_d_hidden=16, ffn_dropout=0.1, ffn_activation='ReGLU', ffn_normalization='LayerNorm', residual_dropout=0.0, prenormalization=True, first_prenormalization=False, last_layer_query_idx=[-1], n_tokens=None, kv_compression_ratio=0.004, kv_compression_sharing='headwise', head_activation='ReLU', head_normalization='LayerNorm', d_out=None):
        feature_tokenizer = rtdl.FeatureTokenizer( 
        transformer = rtdl.Transformer( 
            d_out=d_out  # The number of output classes
        super(FTT, self).__init__(feature_tokenizer, transformer) 

    def forward(self, x_num=None, x_cat=None):
        x = self.feature_tokenizer(x_num, x_cat)
        x = self.cls_token(x)
        x = self.transformer(x)
        return x

num_input_features = 10633  # The number of input features in my dataset
num_classes = 228  # The number of classes in my dataset
model = FTT(n_num_features=num_input_features, d_out=num_classes)

The self.transformer module called from the rtdl package is, I think, the forward method of an object of class Transformer:

def forward(self, x: Tensor) -> Tensor:
    assert (
        x.ndim == 3
    ), 'The input must have 3 dimensions: (n_objects, n_tokens, d_token)'
    for layer_idx, layer in enumerate(self.blocks):
        layer = cast(nn.ModuleDict, layer)

        query_idx = (
            self.last_layer_query_idx if layer_idx + 1 == len(self.blocks) else None
        x_residual = self._start_residual(layer, 'attention', x)
        x_residual, _ = layer['attention'](
            x_residual if query_idx is None else x_residual[:, query_idx],
        if query_idx is not None:
            x = x[:, query_idx]
        x = self._end_residual(layer, 'attention', x, x_residual)

        x_residual = self._start_residual(layer, 'ffn', x)
        x_residual = layer['ffn'](x_residual)
        x = self._end_residual(layer, 'ffn', x, x_residual)
        x = layer['output'](x)

    x = self.head(x)
    return x

I’m still having trouble understanding why this works when I have all features, but fails when I do this:

X_train = X_train[:, 10481:]  # The first 10 481 features of the dataset represent the same kind of data, so I drop them (keeping the last 152) to check how removing them impacts the performance of the model
X_test = X_test[:, 10481:]
# Then I create a train dataloader and a test dataloader, create the model accordingly (with the new number of input features) and proceed to train the model as shown in the code of the first message

Thanks for the link! Based on your cross-post it seems the linear layer is initialized with zero output features as described here?

Dear @ptrblck,

Thanks for your answer. Indeed, the maintainer of the repository rtdl answered me and helped me. The problem was that the output of key_compression (which is called in the forward method of my model) had a size of zero.
However, I don’t understand why it was working when I was on a single GPU. From what I understood, it should always fail, right? I think that’s the main reason I had trouble debugging the code: I didn’t understand why it was working fine on 1x RTX 2080 Ti GPU but failing on 2x RTX 2080 Ti GPUs…
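For reference, the zero-sized key_compression layer can be reproduced with plain arithmetic. This is a sketch under an assumption: the compressed length is taken to be int(n_tokens * kv_compression_ratio), which is consistent with the printed model (in_features=153, out_features=0) but is not quoted from the rtdl source here. compressed_length is a hypothetical helper name.

```python
def compressed_length(n_tokens: int, kv_compression_ratio: float) -> int:
    """Assumed rule: the product is truncated towards zero with int()."""
    return int(n_tokens * kv_compression_ratio)


ratio = 0.004  # the kv_compression_ratio default used in my FTT class

# All features, only the first group, only the last 152 features.
for n_features in (10633, 10481, 152):
    n_tokens = n_features + 1  # the CLS token adds one
    out_features = compressed_length(n_tokens, ratio)
    status = "ok" if out_features > 0 else "zero-sized compression layer"
    print(f"{n_features} features -> {n_tokens} tokens -> {out_features} ({status})")
```

With only 152 features, 153 * 0.004 is about 0.61, which truncates to 0. Under this assumption, a larger kv_compression_ratio (or no compression at all) would be needed for such a small number of tokens.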