Why does torch.cuda.empty_cache() drop GPU utilization to near 0 and slow down training?

In the same part of the code, adding torch.cuda.empty_cache() drops the GPU utilization to roughly 0 and makes training very slow. Why is that?

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = Warmup[args.schedule](optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)

    model.zero_grad()
    model.train()
    global_step = 0
    for epoch in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            batch = tuple(t.to(args.device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, _ = batch
            loss = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            if global_step % args.logging_global_step == 0:
                logger.info("Epoch:{}, Global Step:{}/{}, Loss:{:.5f}".format(epoch, 
                                                                              global_step, 
                                                                              num_train_optimization_steps,
                                                                              loss.item()))
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

            torch.cuda.empty_cache()  # clears the allocator's cache on every iteration

With this line in place, the GPU utilization is basically 0, or floats around 0 the whole time. nvidia-smi shows:

Every 0.1s: nvidia-smi                                                           
Sat Dec 28 10:40:42 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce RTX 208...  Off  | 00000000:08:00.0 Off |                  N/A |
| 35%   47C    P2    62W / 250W |   7177MiB / 11019MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

The logger output:

12/28/2019 10:24:31 - INFO - __main__ -   Epoch:0, Global Step:0/1075.0, Loss:4.84432
12/28/2019 10:44:37 - INFO - __main__ -   Epoch:0, Global Step:100/1075.0, Loss:2.23943

However, when I remove the torch.cuda.empty_cache() line, nvidia-smi shows:

Every 0.1s: nvidia-smi                                                    

Sat Dec 28 10:50:23 2019
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.40       Driver Version: 430.40       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce RTX 208...  Off  | 00000000:08:00.0 Off |                  N/A |
| 40%   54C    P2    89W / 250W |  10831MiB / 11019MiB |     66%      Default |
+-------------------------------+----------------------+----------------------+

The logger output:

12/28/2019 10:49:54 - INFO - __main__ -   Epoch:0, Global Step:0/1075.0, Loss:4.84432
12/28/2019 10:51:24 - INFO - __main__ -   Epoch:0, Global Step:100/1075.0, Loss:2.23943
12/28/2019 10:52:48 - INFO - __main__ -   Epoch:0, Global Step:200/1075.0, Loss:1.72447

The environment is:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2018 NVIDIA Corporation
Built on Sat_Aug_25_21:08:01_CDT_2018
Cuda compilation tools, release 10.0, V10.0.130

Python 3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31)
[GCC 7.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.cuda.is_available()
True

@ptrblck Why is this happening? Thanks!


torch.cuda.empty_cache() will, as the name suggests, empty the reusable GPU memory cache. PyTorch uses a custom caching memory allocator, which reuses freed memory in order to avoid expensive and synchronizing cudaMalloc calls.
Since you are freeing this cache on every iteration, PyTorch has to call cudaMalloc again for each new allocation, which slows down your code.
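
For illustration, here is a minimal sketch (the tensor shape and iteration count are arbitrary, and absolute timings will vary by GPU) that times a loop of allocations with and without emptying the cache:

    import time
    import torch

    def alloc_loop(n_iters=200, empty_cache=False):
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(n_iters):
            # Allocate and free a large tensor. With the cache intact, the
            # freed block is simply reused on the next iteration.
            x = torch.randn(1024, 1024, 64, device="cuda")
            del x
            if empty_cache:
                # Returning the block to the driver forces a fresh,
                # synchronizing cudaMalloc on the next allocation.
                torch.cuda.empty_cache()
        torch.cuda.synchronize()
        return time.time() - start

    print("cache kept   : {:.3f}s".format(alloc_loop(empty_cache=False)))
    print("cache emptied: {:.3f}s".format(alloc_loop(empty_cache=True)))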


In other words, torch.cuda.empty_cache() releases the reusable GPU memory cache, but the price is slower code?
I always thought it could speed up PyTorch.

Your explanation is correct, and I would not recommend using it unless you really need to free the cache for some reason (e.g. another process needs the memory).
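
For that use case, a small sketch of how the cache shows up in the memory stats; torch.cuda.memory_reserved() is the name in recent releases, while older ones call it torch.cuda.memory_cached():

    import torch

    x = torch.randn(4096, 4096, device="cuda")  # ~64 MB tensor
    del x

    # The tensor is freed from PyTorch's point of view, but the block stays
    # in the cache, so nvidia-smi (and other processes) still see it as used.
    print("allocated:", torch.cuda.memory_allocated())  # ~0
    print("reserved :", torch.cuda.memory_reserved())   # still ~64 MB

    torch.cuda.empty_cache()  # return the cached blocks to the driver
    print("reserved :", torch.cuda.memory_reserved())   # now ~0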

I see, thank you very much for answering my doubts.