Does applying iterative pruning cause OOM errors?

I’m trying to reproduce the results from “The State of Sparsity in Deep Neural Networks” using PyTorch. The paper applies iterative pruning to the Transformer network, i.e., it prunes a certain amount every N training steps. When I apply this to the Transformer model (from fairseq) for machine translation, I hit an OOM error around 3 epochs in, as shown below.
Does anyone have a clue why this might be happening?

The pseudo-code of my iterative pruning looks something like this:

for n in range(max_training_step):
    trainer.train_step(...)
    if pruning_condition:
        trainer.prune_model(amount, ...)
        torch.cuda.empty_cache()  # I added this thinking it might help, but it did not
  • I ran the code without any pruning subroutines and it works fine.
  • At the initialization of the trainer class, I iterate over all the modules in the model and append the specific weights I want to prune to a list called self.modules_to_prune. For every pruning iteration, I simply call prune.global_unstructured(self.modules_to_prune, pruning_method=prune.RandomUnstructured, amount=0.2). (A rough sketch of this is right after this list.)
  • There was an issue where, when I applied pruning to the modules, the error "can't call backward twice, use retain_graph=True for your first backward call" appeared. So whenever there was a pruning iteration, I passed retain_graph=True to the loss.backward call, and that made this specific error disappear.
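Roughly, the collection and the pruning call look like the sketch below (the nn.Linear filter and the helper names are just placeholders for illustration, not my actual trainer code):

import torch.nn as nn
import torch.nn.utils.prune as prune

def collect_prunable_weights(model):
    # Corresponds to self.modules_to_prune in my trainer; selecting every
    # nn.Linear weight here is only a placeholder for the real selection.
    return [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Linear)
    ]

def prune_step(modules_to_prune, amount=0.2):
    # Randomly zero out `amount` of the pooled weights across all of the
    # registered (module, "weight") pairs.
    prune.global_unstructured(
        modules_to_prune,
        pruning_method=prune.RandomUnstructured,
        amount=amount,
    )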

Can you give me some suggestions? @Michela

These are the training logs right before the OOM is raised:

"wpb": "3414.7", "bsz": "106.9", "num_updates": "95950", "lr": "0.000102089", "gnorm": "2.053", "train_wall": "9", "wall": "16074"}
2020-08-05 21:20:24 | INFO | train_inner | {"epoch": 3, "update": 2.417, "sparsity": "0.508", "loss": "7.761", "nll_loss": "6.617", "ppl": "98.18", "wps": "19918.3", "ups": "5.77", "wpb": "3451.4", "bsz": "114.2", "num_updates": "96000", "lr": "0.000102062", "gnorm": "1.997", "train_wall": "9", "wall": "16082"}
2020-08-05 21:20:25 | INFO | fairseq.trainer | NOTE: Weights pruned, type: magnitude, amount: 0.017276782789559547, sparsity: 0.5162936572370858
2020-08-05 21:20:25 | WARNING | fairseq.trainer | OOM: Ran out of memory with exception: CUDA out of memory. Tried to allocate 250.00 MiB (GPU 1; 15.78 GiB total capacity; 13.76 GiB already allocated; 63.19 MiB free; 14.54 GiB reserved in total by PyTorch)
Exception raised from malloc at /pytorch/c10/cuda/CUDACachingAllocator.cpp:272 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x2b99525931e2 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x1e64b (0x2b995233464b in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #2: <unknown function> + 0x1f464 (0x2b9952335464 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1faa1 (0x2b9952335aa1 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libc10_cuda.so)
frame #4: at::native::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0x11e (0x2b991a20f90e in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #5: <unknown function> + 0xf33949 (0x2b9918649949 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #6: <unknown function> + 0xf4d777 (0x2b9918663777 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0x10e9c7d (0x2b990863ec7d in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x10e9f97 (0x2b990863ef97 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #9: at::empty(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0xfa (0x2b9908749a1a in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #10: <unknown function> + 0x2eeaa8d (0x2b990a43fa8d in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x10e9f97 (0x2b990863ef97 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #12: at::empty(c10::ArrayRef<long>, c10::TensorOptions const&, c10::optional<c10::MemoryFormat>) + 0xfa (0x2b9908749a1a in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::native::zeros(c10::ArrayRef<long>, c10::TensorOptions const&) + 0x25 (0x2b99083c10c5 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #14: <unknown function> + 0x128b2f3 (0x2b99087e02f3 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x2eb3059 (0x2b990a408059 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0x10ea319 (0x2b990863f319 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #17: at::zeros(c10::ArrayRef<long>, c10::TensorOptions const&) + 0xd5 (0x2b9908734fb5 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #18: torch::autograd::generated::GatherBackward::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x209 (0x2b990a279f89 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #19: <unknown function> + 0x3375bb7 (0x2b990a8cabb7 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #20: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0x1400 (0x2b990a8c6400 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #21: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x451 (0x2b990a8c6fa1 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #22: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x89 (0x2b990a8bf119 in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_cpu.so)
frame #23: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x4a (0x2b99064b54ba in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch_python.so)
frame #24: <unknown function> + 0xc70f (0x2b990734d70f in /nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/lib/libtorch.so)
frame #25: <unknown function> + 0x7dd5 (0x2b9873382dd5 in /usr/lib64/libpthread.so.0)
frame #26: clone + 0x6d (0x2b9873694ead in /usr/lib64/libc.so.6)

2020-08-05 21:20:25 | WARNING | fairseq.trainer | |===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| GPU reserved memory   |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Allocations           |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|===========================================================================|

2020-08-05 21:20:25 | WARNING | fairseq.trainer | |===========================================================================|
|                  PyTorch CUDA memory summary, device ID 1                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 1            |        cudaMalloc retries: 37        |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   14088 MB |   14457 MB |  555185 GB |  555171 GB |
|       from large pool |   13696 MB |   14065 MB |  535265 GB |  535252 GB |
|       from small pool |     392 MB |     542 MB |   19919 GB |   19919 GB |
|---------------------------------------------------------------------------|
| Active memory         |   14088 MB |   14457 MB |  555185 GB |  555171 GB |
|       from large pool |   13696 MB |   14065 MB |  535265 GB |  535252 GB |
|       from small pool |     392 MB |     542 MB |   19919 GB |   19919 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   14890 MB |   14950 MB |   92649 GB |   92634 GB |
|       from large pool |   14488 MB |   14488 MB |   91291 GB |   91277 GB |
|       from small pool |     402 MB |     550 MB |    1358 GB |    1357 GB |
|---------------------------------------------------------------------------|
| Non-releasable memory |     801 MB |    2085 MB |  497497 GB |  497496 GB |
|       from large pool |     791 MB |    2049 MB |  477101 GB |  477100 GB |
|       from small pool |       9 MB |      95 MB |   20395 GB |   20395 GB |
|---------------------------------------------------------------------------|
| Allocations           |    1369    |    1615    |  204541 K  |  204540 K  |
|       from large pool |     378    |     407    |   91557 K  |   91556 K  |
|       from small pool |     991    |    1374    |  112984 K  |  112983 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    1369    |    1615    |  204541 K  |  204540 K  |
|       from large pool |     378    |     407    |   91557 K  |   91556 K  |
|       from small pool |     991    |    1374    |  112984 K  |  112983 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     280    |     351    |    1649 K  |    1649 K  |
|       from large pool |      79    |      85    |     954 K  |     954 K  |
|       from small pool |     201    |     275    |     695 K  |     695 K  |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      68    |     180    |  126454 K  |  126454 K  |
|       from large pool |      25    |      78    |   54307 K  |   54307 K  |
|       from small pool |      43    |     126    |   72146 K  |   72146 K  |
|===========================================================================|

2020-08-05 21:20:25 | WARNING | fairseq.trainer | attempting to recover from OOM in forward/backward pass

Iterative pruning shouldn’t, by itself, cause any OOM issues (although I imagine your network is quite big and pruning in general needs to compute and attach masks at each iteration).

I’m suspicious of the retain_graph=True bit. That’s literally not allowing the memory used to store the graph to be freed, which probably causes the memory accumulation and eventual OOM. Instead of just setting that to true, can we try to find out what’s causing that error to be raised in the first place? It might have something to do with your specific architecture or training loop, because it doesn’t happen in general.
Do you have a minimal lightweight example that reproduces that error?

Thanks for the reply.

  1. Yes, it is a relatively big model, and yes, I think the size of the task itself may have something to do with the error. (I’ve tried the same idea on a much lighter Transformer architecture and didn’t run into any OOM errors.)

  2. As for the retain_graph=True part, I know it’s really a hack at this point, but I’m applying torch.cuda.empty_cache() right after I use the retain_graph=True flag. To be honest, I’m not sure this API is letting the model free the buffers it was clinging onto due to retain_graph=True.

  3. Like you mentioned, since every pruning step creates a new mask and attaches it to the model, do you think this might cause the model to grow incrementally and lead to this issue? If so, do you think it would help to call prune.remove at every pruning step, so that the masks are effectively discarded? (See the sketch right after this list.)

  4. As you suggested, I’ll take a deeper look into why the trainer is complaining about calling backward twice when pruning is applied. If cuda.empty_cache is not doing what I think it’s doing then probably retain_graph=True is causing the issue.
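A sketch of what I mean in point 3, using the same hypothetical modules_to_prune list as above:

import torch.nn.utils.prune as prune

def prune_and_make_permanent(modules_to_prune, amount):
    # Prune globally as before ...
    prune.global_unstructured(
        modules_to_prune,
        pruning_method=prune.RandomUnstructured,
        amount=amount,
    )
    # ... then immediately bake the mask into the weight: this drops
    # weight_orig / weight_mask and the forward pre-hook that reapplies
    # the mask on every forward pass.
    for module, name in modules_to_prune:
        prune.remove(module, name)

(After prune.remove, each parameter is back to a single plain tensor, with no weight_orig, weight_mask, or hook left on the module.)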

Thank you so much. Love your work. You rock.

No, so the way it’s implemented is such that the mask keeps on getting replaced. We don’t keep on storing more and more masks with every iteration.

I’m most interested in your points 2 and 4. I don’t know exactly what torch.cuda.empty_cache() is supposed to do, but it’s not helping here, as you said. I want to understand if we can avoid setting retain_graph=True altogether and therefore potentially solve your problem (assuming that that’s what’s causing it). Are you explicitly calling .backward twice anywhere in your training loop, perhaps on two different parts of the network? If not, it might still be happening implicitly somewhere and we’d have to find out why and how. This issue is related to the mask application, I’m assuming, but I can’t yet tell in what exact way…
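For context, a toy snippet like the one below (nothing to do with fairseq or pruning) is a minimal way to hit that error, and it also shows why retain_graph=True is only a band-aid:

import torch

x = torch.randn(4, requires_grad=True)
y = x.exp().sum()   # exp() saves its output for use in the backward pass

y.backward()        # fine; the graph's saved buffers are freed afterwards
try:
    y.backward()    # backward through the same graph a second time
except RuntimeError as err:
    print(err)      # "Trying to backward through the graph a second time ..."

# Passing retain_graph=True to the first call silences the error, but only
# by keeping all the saved buffers alive, which is exactly the kind of thing
# that can pile up into an OOM on a large model.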

I’ll let you take a deeper look, but if you ever get to a super reduced version of the code that can be easily run and shows the same error, I can probably help take it apart for debugging.

I did some digging into the codebase I’m using.
The only explicit backward call is here: https://github.com/pytorch/fairseq/blob/master/fairseq/trainer.py#L425
where it’s actually just a wrapper that eventually calls https://github.com/pytorch/fairseq/blob/8aa06aa03b596de58d106d3f55ff43e2b9aa0b80/fairseq/optim/fairseq_optimizer.py#L85.

I know that, besides calling backward twice, accessing certain nodes in the graph can cause this problem, but I don’t see anything out of the ordinary.

I’ve done some commenting out and noticed that the error appeared whenever this specific code segment was executed. (When this block was commented out, the error did not appear.)
This code segment is executed right before a single training step (https://github.com/ypsoh/fairseq/blob/add_pruning_routine/fairseq/trainer.py#L498)

prune.global_unstructured(
    self.modules_to_prune,
    pruning_method=prune.L1Unstructured if self.prune_type == 'magnitude'
                   else prune.RandomUnstructured,
    amount=pruning_amount,
)

So what I did was remove all the retain_graph=True calls and wrap this call in a with torch.no_grad(): block, and the error seemed to go away. The model also seems to be pruned as expected.
Do you think this is a viable solution? Or does this hack give you some insight on what might have been causing this issue?

If you prune under no_grad, I don’t think the mask will be correctly applied and backpropagated through in the backward pass. This should show up as the pruned weights getting updated, which is not the correct behavior. Worth checking.
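One cheap way to check (a sketch; modules_to_prune stands for whatever list of (module, "weight") pairs was registered for pruning) is to call something like this right after a training step:

import torch

def effective_sparsity(modules_to_prune):
    # Fraction of exactly-zero entries in the weights the forward pass uses.
    zeros, total = 0, 0
    for module, name in modules_to_prune:
        weight = getattr(module, name).detach()
        zeros += (weight == 0).sum().item()
        total += weight.numel()
    return zeros / total

If the pruned entries were being updated by the optimizer, this number would drift below the target sparsity over time.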

Out of curiosity, does the issue happen for both global unstructured and random unstructured pruning?

Another dumb question: is there any recurrence in the model? Any fancy weight sharing? I’m asking because if a parameter is reused in different ways (for example when an LSTM is unrolled over the time-steps), then we might be having issues re-accessing the graph info after the first time. This would be a problem to pin down and address in pruning.

Yes, it turns out that torch.no_grad was neither working as intended nor resolving the OOM issues.

  • I haven’t tried random_unstructured yet, but I’ll see if that’s any different.

Any fancy weight sharing? I’m asking because if a parameter is reused in different ways (for example when an LSTM is unrolled over the time-steps), then we might be having issues re-accessing the graph info after the first time. This would be a problem to pin down and address in pruning.

For my configuration, I’m sharing the embedding tables between the encoder and the decoder (I think this is fairly standard practice in NMT). Of course, when pruning I only register the original module as subject to pruning, not the “tied” one. Other than that, I don’t think there is any recurrence or weight sharing in the standard Transformer architecture.

Do you have any thoughts on how this routine might behave in a multi-GPU training setting?
When I was getting the OOM error, I noticed that only one GPU was out of memory, whereas all the other GPUs were basically empty (although during training all of them were being utilized).

2020-08-06 07:12:49 | INFO | fairseq.trainer | NOTE: Weights pruned, type: magnitude, amount: 0.009999324452827444, sparsity: 0.7000000065770643 2020-08-06 07:12:50 | WARNING | fairseq.trainer | OOM: Ran out of memory with exception: CUDA out of memory. Tried to allocate 246.00 MiB (GPU 0; 31.75 GiB total capacity; 26.00 GiB already allocated; 46

OOM on one GPU

2020-08-06 07:12:50 | WARNING | fairseq.trainer | |===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|           CUDA OOMs: 1             |       cudaMalloc retries: 6          |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   26377 MB |   26988 MB |  767743 GB |  767717 GB |
|       from large pool |   25981 MB |   26595 MB |  738567 GB |  738542 GB |
|       from small pool |     395 MB |     471 MB |   29175 GB |   29175 GB |
|---------------------------------------------------------------------------|
| Active memory         |   26377 MB |   26988 MB |  767743 GB |  767717 GB |
|       from large pool |   25981 MB |   26595 MB |  738567 GB |  738542 GB |
|       from small pool |     395 MB |     471 MB |   29175 GB |   29175 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |   28544 MB |   28648 MB |   32826 MB |    4282 MB |
|       from large pool |   28138 MB |   28176 MB |   32024 MB |    3886 MB |
|       from small pool |     406 MB |     478 MB |     802 MB |     396 MB |
|---------------------------------------------------------------------------|
| Non-releasable memory |    1702 MB |    2343 MB |  670567 GB |  670565 GB |
|       from large pool |    1692 MB |    2325 MB |  641131 GB |  641129 GB |
|       from small pool |      10 MB |      74 MB |   29435 GB |   29435 GB |
|---------------------------------------------------------------------------|
| Allocations           |    1431    |    1533    |  296924 K  |  296923 K  |
|       from large pool |     420    |     434    |  133413 K  |  133413 K  |
|       from small pool |    1011    |    1299    |  163510 K  |  163509 K  |
|---------------------------------------------------------------------------|
| Active allocs         |    1431    |    1533    |  296924 K  |  296923 K  |
|       from large pool |     420    |     434    |  133413 K  |  133413 K  |
|       from small pool |    1011    |    1299    |  163510 K  |  163509 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |     343    |     398    |     616    |     273    |
|       from large pool |     140    |     159    |     215    |      75    |
|       from small pool |     203    |     239    |     401    |     198    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |     154    |     237    |  195998 K  |  195998 K  |
|       from large pool |      88    |     147    |   96157 K  |   96157 K  |
|       from small pool |      66    |     124    |   99841 K  |   99841 K  |
|===========================================================================|

Nothing on other GPUs

2020-08-06 07:12:50 | WARNING | fairseq.trainer | |===========================================================================|
|                  PyTorch CUDA memory summary, device ID 2                 |
|---------------------------------------------------------------------------|
|           CUDA OOMs: 0             |       cudaMalloc retries: 0          |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| GPU reserved memory   |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Allocations           |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Active allocs         |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| GPU reserved segments |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |       0    |       0    |       0    |       0    |
|       from large pool |       0    |       0    |       0    |       0    |
|       from small pool |       0    |       0    |       0    |       0    |
|===========================================================================|

A bit of an update: I fixed the original issue without using retain_graph=True, so no hacks involved.

I know that you previously said:

No, so the way it’s implemented is such that the mask keeps on getting replaced. We don’t keep on storing more and more masks with every iteration.

However, by the time iterative pruning has been applied roughly 100 times, an OOM still occurs. The message looks like this and seems more directly related to pruning:


-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/nfs_home/sohyongs/Workspace/fairseq/fairseq/distributed_utils.py", line 156, in distributed_main
    main(args, **kwargs)
  File "/nfs_home/sohyongs/Workspace/fairseq/fairseq_cli/train.py", line 121, in main
    valid_losses, should_stop = train(args, trainer, task, epoch_itr)
  File "/nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/nfs_home/sohyongs/Workspace/fairseq/fairseq_cli/train.py", line 217, in train
    log_output = trainer.train_step(samples)
  File "/nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/contextlib.py", line 52, in inner
    return func(*args, **kwds)
  File "/nfs_home/sohyongs/Workspace/fairseq/fairseq/trainer.py", line 475, in train_step
    amount=pruning_amount)
  File "/nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/nn/utils/prune.py", line 1049, in global_unstructured
    final_mask = container.compute_mask(t, default_mask)
  File "/nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/nn/utils/prune.py", line 391, in compute_mask
    mask = _combine_masks(method, t, default_mask)
  File "/nfs_home/sohyongs/anaconda3/envs/fairseq/lib/python3.6/site-packages/torch/nn/utils/prune.py", line 386, in _combine_masks
    new_mask[slc] = partial_mask.to(dtype=new_mask.dtype)
R

Maybe the _combine_masks computation requires too much intermediate memory? Would it be better practice to apply pruning on the CPU and then move the result to the device?
Since this seems more directly related to pruning, can you share any suspicions? @Michela Thanks!

My suspicion is that global pruning requires an operation that consumes a lot of memory. I changed it to normal module-based pruning and the OOM no longer occurs.
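Roughly, the layer-wise version looks like this (a sketch with the same hypothetical modules_to_prune list, not my exact trainer code):

import torch.nn.utils.prune as prune

def prune_step_layerwise(modules_to_prune, amount, prune_type="magnitude"):
    # Each call only touches a single parameter tensor at a time, instead of
    # the one big concatenated score vector that global_unstructured builds.
    for module, name in modules_to_prune:
        if prune_type == "magnitude":
            prune.l1_unstructured(module, name=name, amount=amount)
        else:
            prune.random_unstructured(module, name=name, amount=amount)

Note that amount is then interpreted per tensor rather than across the pooled weights, so the resulting sparsity pattern is not identical to true global magnitude pruning.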

Hi @ypsoh, I have exactly the same problem. Could you tell me how you fixed it without setting retain_graph=True?

Hi @ypsoh, I also ran into the OOM problem with global pruning. Strangely, the free GPU memory does seem to drop after I call global pruning (though not every time). How did you manage to change it to normal (module-based) pruning for all the parameters?

@Michela
Here is a minimal example of iteratively calling global_pruning and training the model for a few steps after pruning: https://colab.research.google.com/drive/14t8E_17NuOY46iOwuB_TDTVUn-SVvh1a?usp=sharing

You can see the used GPU memory increases after each prune.

If I remove the training steps, the memory does not increase. So I guess the problem comes from training the model? But why would that be a problem?
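For anyone who wants to reproduce this outside the notebook, a sketch of the measurement I mean (modules_to_prune and the training steps are placeholders):

import torch
import torch.nn.utils.prune as prune

for it in range(10):                      # a handful of pruning iterations
    # ... run a few training steps on the model here ...
    prune.global_unstructured(
        modules_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=0.1,
    )
    torch.cuda.synchronize()
    print(f"iter {it}: "
          f"allocated = {torch.cuda.memory_allocated() / 2**20:.0f} MiB, "
          f"reserved = {torch.cuda.memory_reserved() / 2**20:.0f} MiB")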

Hello, I’m having the exact same issue with fairseq. Did you guys @ypsoh @Edi_Chan @Michela manage to fix the retain_graph=True issue? How? Many thanks!