Using apex leads to a CUDA out-of-memory error on an A100

I’m using apex in a conda environment; it was installed from
When I run my training script:

LANG=python                       # set python for py150
#PRETRAINDIR="sberbank-ai/rugpt3large_based_on_gpt2"       # directory of your saved model



python3 -u -m torch.distributed.launch --nproc_per_node=$PER_NODE_GPU  code/ \
        --data_dir=$DATADIR \
        --lit_file=$LITFILE \
        --langs=$LANG \
        --output_dir=$OUTPUTDIR \
        --pretrain_dir=$PRETRAINDIR \
        --log_file=$LOGFILE \
        --model_type=$GPT2MED \
        --block_size=1024 \
        --do_train \
        --evaluate_during_training \
        --do_eval \
        --logging_steps=100 \
        --save_steps=500 \
        --save_total_limit 4 \
        --seed=42 \
        --fp16 \
        --gpu_per_node $PER_NODE_GPU \
        --learning_rate=8e-5 \
        --weight_decay=0.01 \
        --per_gpu_train_batch_size=2 \
        --per_gpu_eval_batch_size=4 \
        --gradient_accumulation_steps=4 \
        --num_train_epochs=5 \
        --overwrite_output_dir \
        --not_pretrain \
        --tensorboard_dir ./tensorboard_logs

I run into an error:

/usr/local/lib/python3.8/dist-packages/torch/distributed/ FutureWarning: The module torch.distributed.launch is deprecated
and will be removed in future. Use torchrun.
Note that --use_env is set by default in torchrun.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See for 
further instructions

11/25/2021 16:11:40 - WARNING - __main__ -   Process rank: -1, device: cuda:0, n_gpu: 1, distributed training: False, 16-bits training: True, world size: 1
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
/usr/local/lib/python3.8/dist-packages/transformers/models/auto/ FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.
11/25/2021 16:11:44 - INFO - __main__ -   Model has a total of 355066880 trainable parameters
11/25/2021 16:11:44 - INFO - __main__ -   Training/evaluation parameters Namespace(adam_epsilon=1e-08, block_size=1024, cache_dir='', config_dir=None, data_dir='/home/ubuntu/CodeXGLUE/Code-Code/CodeCompletion-token/dataset/py150/token_completion', device=device(type='cuda', index=0), do_eval=True, do_lower_case=False, do_train=True, eval_all_checkpoints=False, evaluate_during_training=True, fp16=True, fp16_opt_level='O1', gpu_per_node=1, gradient_accumulation_steps=4, langs='python', learning_rate=8e-05, lit_file='/home/ubuntu/CodeXGLUE/Code-Code/CodeCompletion-token/dataset/py150/literals.json', load_name='pretrained', local_rank=-1, log_file='completion_py150_eval.log', logging_steps=100, max_grad_norm=1.0, max_steps=-1, mlm=False, mlm_probability=0.15, model_type='gpt2-medium', n_gpu=1, no_cuda=False, node_index=-1, not_pretrain=True, num_train_epochs=5.0, output_dir='../save/py150', overwrite_cache=False, overwrite_output_dir=True, per_gpu_eval_batch_size=4, per_gpu_train_batch_size=2, pretrain_dir='gpt2-medium', save_steps=500, save_total_limit=4, seed=42, server_ip='', server_port='', start_epoch=0, start_step=0, tensorboard_dir='./tensorboard_logs', tokenizer_dir=None, warmup_steps=0, weight_decay=0.01)
11/25/2021 16:11:44 - WARNING - __main__ -   Loading features from cached file ../save/py150/train_blocksize_1024_wordsize_1_rank_0
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
11/25/2021 16:11:53 - INFO - __main__ -   ***** Running training *****
11/25/2021 16:11:53 - INFO - __main__ -     Num examples = 126276
11/25/2021 16:11:53 - INFO - __main__ -     Num epoch = 4
11/25/2021 16:11:53 - INFO - __main__ -     Instantaneous batch size per GPU = 2
11/25/2021 16:11:53 - INFO - __main__ -     Total train batch size (w. parallel, distributed & accumulation) = 8
11/25/2021 16:11:53 - INFO - __main__ -     Gradient Accumulation steps = 4
11/25/2021 16:11:53 - INFO - __main__ -     Total optimization steps = 78920
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
/usr/local/lib/python3.8/dist-packages/torch/optim/ UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at
  warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0
Traceback (most recent call last):
  File "code/", line 725, in <module>
  File "code/", line 713, in main
    global_step, tr_loss = train(args, train_dataset, model, tokenizer, fh, pool)
  File "code/", line 184, in train
    outputs = model(inputs, labels=labels)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/", line 1073, in forward
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/", line 1150, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/usr/local/lib/python3.8/dist-packages/apex/amp/", line 25, in wrapper
    new_args = utils.casted_args(cast_fn,
  File "/usr/local/lib/python3.8/dist-packages/apex/amp/", line 81, in casted_args
  File "/usr/local/lib/python3.8/dist-packages/apex/amp/", line 74, in maybe_float
    return x.float()
RuntimeError: CUDA out of memory. Tried to allocate 396.00 MiB (GPU 0; 22.20 GiB total capacity; 19.82 GiB already allocated; 170.12 MiB free; 19.88 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1403) of binary: /usr/bin/python3
Traceback (most recent call last):
  File "/usr/lib/python3.8/", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/", line 193, in <module>
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/", line 189, in main
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/", line 174, in launch
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/", line 710, in run
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.8/dist-packages/torch/distributed/launcher/", line 259, in launch_agent
    raise ChildFailedError(
code/ FAILED
Root Cause (first observed failure):
  time      : 2021-11-25_16:12:04
  host      : ip-172-31-4-91.ec2.internal
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1403)
  error_file: <N/A>
  traceback : To enable traceback see:

This seems ridiculous, since an A100 has plenty of memory.
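Note that the OOM message reports "22.20 GiB total capacity" for GPU 0, which is well below the 40 GB or 80 GB of an A100. It may be worth confirming which device PyTorch actually sees. The sketch below uses only `torch.cuda.get_device_name` and `torch.cuda.get_device_properties(i).total_memory`; the helper name `describe_gpus` is hypothetical.

```python
import torch


def describe_gpus():
    """Return a list of (device name, total memory in GiB) for each visible CUDA device."""
    if not torch.cuda.is_available():
        return []
    return [
        (torch.cuda.get_device_name(i),
         torch.cuda.get_device_properties(i).total_memory / 2**30)  # bytes -> GiB
        for i in range(torch.cuda.device_count())
    ]


if __name__ == "__main__":
    gpus = describe_gpus()
    if not gpus:
        print("No CUDA device visible to PyTorch")
    for i, (name, gib) in enumerate(gpus):
        print(f"cuda:{i}: {name}, {gib:.2f} GiB total")
```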

How can this be fixed? Perhaps there’s a working combination of PyTorch, CUDA, transformers, and apex versions that doesn’t run into this error?
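One plausible reading of the traceback: the failing allocation happens in apex's `maybe_float` (`x.float()`) inside `F.cross_entropy`, i.e. where O1 casts the language-model logits to float32 for the loss. A rough size estimate for a logits tensor of shape (batch, sequence, vocabulary) is sketched below. Batch 2 and length 1024 come from the logged `per_gpu_train_batch_size=2` and `block_size=1024`; the vocabulary size of 50257 is an assumption (GPT-2's base vocabulary, before the special tokens the log says were added). The helper `tensor_mib` is hypothetical.

```python
def tensor_mib(*shape, bytes_per_elem):
    """Size in MiB of a dense tensor with the given shape and element size."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_elem / 2**20


# Assumed values: batch and sequence length from the training log,
# GPT-2's base vocabulary (the added special tokens make it slightly larger).
BATCH, SEQ_LEN, VOCAB = 2, 1024, 50257

fp16_logits = tensor_mib(BATCH, SEQ_LEN, VOCAB, bytes_per_elem=2)
fp32_logits = tensor_mib(BATCH, SEQ_LEN, VOCAB, bytes_per_elem=4)
print(f"fp16 logits: {fp16_logits:.1f} MiB")  # fp16 logits: 196.3 MiB
print(f"fp32 copy:   {fp32_logits:.1f} MiB")  # fp32 copy:   392.6 MiB
```

The fp32 figure is in the same range as the 396.00 MiB request in the error. Since the loss path can hold several tensors of roughly this size at once (the logits, the contiguous `shift_logits` copy, the float32 cast, and later their gradients), the loss computation alone can account for well over a gigabyte at this batch size and block length.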

apex.amp is deprecated in favor of the native mixed-precision utilities available via torch.cuda.amp. Take a look at the examples to see their usage.
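As a rough sketch of the migration, assuming a generic classification-style model and data rather than the original training script: `torch.autocast` takes over the per-op casting that `amp.initialize(..., opt_level="O1")` provided, and `GradScaler` replaces `with amp.scale_loss(loss, optimizer) as scaled_loss:`. The accumulation steps, gradient-clipping norm, learning rate, and weight decay mirror the logged `Namespace`; the function name `train_amp` and its arguments are hypothetical. Loss-related ops such as `log_softmax` still run in float32 under autocast, so this change alone may not remove the OOM.

```python
from contextlib import nullcontext

import torch
import torch.nn.functional as F


def train_amp(model, batches, lr=8e-5, weight_decay=0.01,
              accum_steps=4, max_grad_norm=1.0):
    """Mixed-precision training loop with gradient accumulation (sketch)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"  # fp16 autocast and loss scaling only on GPU here
    model.to(device).train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                  weight_decay=weight_decay)
    # Dynamic loss scaling (what apex's "loss_scale: dynamic" did); no-op when disabled.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    losses = []
    optimizer.zero_grad(set_to_none=True)
    for step, (inputs, labels) in enumerate(batches, start=1):
        inputs, labels = inputs.to(device), labels.to(device)
        # Per-op automatic casting, the role of apex's O1 patching.
        ctx = torch.autocast("cuda", dtype=torch.float16) if use_amp else nullcontext()
        with ctx:
            logits = model(inputs)
            loss = F.cross_entropy(logits, labels)
        # Replaces `with amp.scale_loss(loss, optimizer) as scaled_loss:`.
        scaler.scale(loss / accum_steps).backward()
        # Step once every accum_steps micro-batches; a trailing partial
        # accumulation is simply dropped in this sketch.
        if step % accum_steps == 0:
            scaler.unscale_(optimizer)  # so clipping sees the true gradient norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)  # skipped if gradients contain inf/NaN
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
        losses.append(loss.item())
    return losses
```

Any learning-rate scheduler would be stepped right after `scaler.step(optimizer)`, which keeps the "optimizer.step() before lr_scheduler.step()" order that the warning in the log asks for.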

The linked package is also not official and therefore doesn’t ship pre-built C++/CUDA kernels.
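A quick way to see whether an installed apex includes its compiled extensions is to look for the extension modules by name. The module names below (`apex_C`, `amp_C`, `fused_layer_norm_cuda`) are an assumption based on what the NVIDIA apex repository builds with its `--cpp_ext`/`--cuda_ext` options; a Python-only install would lack them. The helper `apex_extension_status` is hypothetical.

```python
import importlib.util

# Assumed names of modules that apex builds only with --cpp_ext / --cuda_ext.
APEX_EXTENSIONS = ("apex_C", "amp_C", "fused_layer_norm_cuda")


def apex_extension_status(names=APEX_EXTENSIONS):
    """Map each module name to True if it can be found on the import path."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    if importlib.util.find_spec("apex") is None:
        print("apex is not installed")
    else:
        for name, found in apex_extension_status().items():
            print(f"{name}: {'found' if found else 'missing (Python-only build?)'}")
```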