How to debug causes of GPU memory leaks?

Hi,

I have been trying to figure out why my code crashes after several batches with a CUDA out-of-memory error. I understand that there is probably some variable that is not freed because I keep it in the graph. The question is: how do I debug that kind of thing?

I wrote a simple line profiler that examines the amount of GPU memory at each step, and it looks like on each loss.backward() step there is a massive memory allocation, not all of it is freed, and that piles up over a couple of absolutely identical iterations until it crashes. It always crashes on backward(). I understand that PyTorch reuses memory, which is why it might seem like it is not freeing memory, but here it seems like something is indeed leaking.
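
Roughly, this is the kind of per-step check I mean (a much simplified sketch using PyTorch's own counters; my actual profiler traces every line):

import torch

def log_gpu_mem(tag):
    # memory occupied by live tensors vs. memory held by the caching allocator
    alloc = torch.cuda.memory_allocated() / 1024 ** 2
    reserved = torch.cuda.memory_reserved() / 1024 ** 2  # memory_cached() on older versions
    print(f"{tag}: allocated={alloc:.1f}Mb reserved={reserved:.1f}Mb")

# inside the training loop:
# log_gpu_mem("before backward")
# loss.backward()
# log_gpu_mem("after backward")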

I tried explicitly del’ing variables (e.g. the input data) to mark them for reallocation by torch, with no luck. I also set CUDA_LAUNCH_BLOCKING=1 to force kernels to execute synchronously.

[GPU memory trace]

Is there a way to get a memory footprint like “all tensors allocated on GPU”?

Chances are that it is not backward()'s fault, e.g. something like:
a +10Mb  (30 free)
b +10Mb  (20 free)
c +20Mb  (0 free)
cleanup: b leaks and is never marked as free; the other 30 are freed but not released
a +10Mb  - no allocation (20 actually free)
b +10Mb  - no allocation (10 actually free)
c +20Mb  - needs a new large allocation and crashes, even though it is b that is leaking

Anyway, how do I track these things down?

34 Likes

Great question, hope someone answers. This will be useful to me too!

3 Likes

In Python, you can use the garbage collector’s bookkeeping to print out the currently resident tensors. Here’s a snippet that shows all the currently allocated Tensors:

# prints currently alive Tensors and Variables
import torch
import gc
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        # some objects (e.g. loaded shared libraries) raise when inspected
        pass

Edited on April 7th, 2019 to add a try/except

56 Likes

Thanks! It seems to work with a try/except block around it (some objects, like shared libraries, throw an exception when you call hasattr on them).

I extended my code that tracked memory usage to also track where memory allocations appeared, by comparing the set of tensors before and after each operation. The results are somewhat surprising: I stopped execution after the first batch (it breaks on a GPU memory allocation in the second batch), and memory consumption was higher in the case where fewer tensors were allocated O_o
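
The comparison itself is roughly this (a simplified sketch; my real code also records the file and line where each tensor first appeared):

import gc
import torch

def live_gpu_tensors():
    # set of (id, type, shape) for every CUDA tensor the gc can see
    tensors = set()
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                t = obj
            elif hasattr(obj, 'data') and torch.is_tensor(obj.data):
                t = obj.data
            else:
                continue
            if t.is_cuda:
                tensors.add((id(t), t.type(), tuple(t.size())))
        except Exception:
            pass
    return tensors

# before = live_gpu_tensors()
# ...run the suspicious step, e.g. loss.backward()...
# for _, ttype, shape in live_gpu_tensors() - before:
#     print('+', ttype, shape)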

Here is the diff between the two sorted listings below, i.e. the run with more allocations ended up consuming less memory. One possible explanation is that in the model that requested more memory, a single model was applied twice and its gradient was computed, whereas in the “less consuming” case, the trained model is applied once and then a fixed model is applied once.

< + start freeze_params:91                             (128, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (128, 64, 3, 3)      <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (128,)               <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (21, 21, 32, 32)     <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (21, 21, 4, 4)       <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (21, 4096, 1, 1)     <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (21, 512, 1, 1)      <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (21,)                <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (256, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (256, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (256,)               <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (4096, 4096, 1, 1)   <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (4096, 512, 7, 7)    <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (4096,)              <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (512, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (512, 512, 3, 3)     <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (512,)               <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (64, 3, 3, 3)        <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (64, 64, 3, 3)       <class 'torch.cuda.FloatTensor'>
< + start freeze_params:91                             (64,)                <class 'torch.cuda.FloatTensor'>

Here are the original logs:

:7938.9 Mb

+ __main__ match_source_target:174                   (1, 3, 1052, 1914)   <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:178                   (1, 3, 1052, 1914)   <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (128, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (128, 64, 3, 3)      <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (128,)               <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (21, 4096, 1, 1)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (21,)                <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (256, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (256, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (256,)               <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (4096, 4096, 1, 1)   <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (4096, 512, 7, 7)    <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (4096,)              <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (512, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (512, 512, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (512,)               <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (64, 3, 3, 3)        <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (64, 64, 3, 3)       <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (64,)                <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (128, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (128, 64, 3, 3)      <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (128,)               <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 21, 32, 32)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 21, 4, 4)       <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 4096, 1, 1)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 512, 1, 1)      <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21,)                <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (256, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (256, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (256,)               <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (4096, 4096, 1, 1)   <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (4096, 512, 7, 7)    <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (4096,)              <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (512, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (512, 512, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (512,)               <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (64, 3, 3, 3)        <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (64, 64, 3, 3)       <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (64,)                <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (1, 500)             <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (1,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (500, 21)            <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (500, 500)           <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (500,)               <class 'torch.cuda.FloatTensor'>
+ distances.mlp objective:26                         (2,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp objective:33                         (1,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (1, 500)             <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (1,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (500, 21)            <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (500, 500)           <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (500,)               <class 'torch.cuda.FloatTensor'>
+ model.fcn16 __init__:13                            (21,)                <class 'torch.cuda.FloatTensor'>
+ model.fcn16 features_at:45                         (1, 4096, 34, 60)    <class 'torch.cuda.FloatTensor'>
+ model.fcn16 features_at:48                         (1, 4096, 34, 60)    <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (128, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (128, 64, 3, 3)      <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (128,)               <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (21, 21, 32, 32)     <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (21, 21, 4, 4)       <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (21, 4096, 1, 1)     <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (21, 512, 1, 1)      <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (21,)                <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (256, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (256, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (256,)               <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (4096, 4096, 1, 1)   <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (4096, 512, 7, 7)    <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (4096,)              <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (512, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (512, 512, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (512,)               <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (64, 3, 3, 3)        <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (64, 64, 3, 3)       <class 'torch.cuda.FloatTensor'>
+ start freeze_params:91                             (64,)                <class 'torch.cuda.FloatTensor'>

And in the second case it has fewer tensors (i.e. identical to the above except ~10 tensors fewer), but higher memory consumption, and it breaks on the second epoch. Any ideas?

:11820.9Mb 

+ __main__ match_source_target:174                   (1, 3, 1052, 1914)   <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:178                   (1, 3, 1052, 1914)   <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (128, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (128, 64, 3, 3)      <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (128,)               <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (21, 4096, 1, 1)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (21,)                <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (256, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (256, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (256,)               <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (4096, 4096, 1, 1)   <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (4096, 512, 7, 7)    <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (4096,)              <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (512, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (512, 512, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (512,)               <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (64, 3, 3, 3)        <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (64, 64, 3, 3)       <class 'torch.cuda.FloatTensor'>
+ __main__ match_source_target:190                   (64,)                <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (128, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (128, 64, 3, 3)      <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (128,)               <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 21, 32, 32)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 21, 4, 4)       <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 4096, 1, 1)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21, 512, 1, 1)      <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (21,)                <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (256, 128, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (256, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (256,)               <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (4096, 4096, 1, 1)   <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (4096, 512, 7, 7)    <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (4096,)              <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (512, 256, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (512, 512, 3, 3)     <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (512,)               <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (64, 3, 3, 3)        <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (64, 64, 3, 3)       <class 'torch.cuda.FloatTensor'>
+ __main__ run_adaptation:355                        (64,)                <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (1, 500)             <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (1,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (500, 21)            <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (500, 500)           <class 'torch.cuda.FloatTensor'>
+ distances.mlp _init_net:55                         (500,)               <class 'torch.cuda.FloatTensor'>
+ distances.mlp objective:26                         (2,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp objective:33                         (1,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (1, 500)             <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (1,)                 <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (500, 21)            <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (500, 500)           <class 'torch.cuda.FloatTensor'>
+ distances.mlp_base attempt_update_d:77             (500,)               <class 'torch.cuda.FloatTensor'>
+ model.fcn16 __init__:13                            (21,)                <class 'torch.cuda.FloatTensor'>
+ model.fcn16 features_at:45                         (1, 4096, 34, 60)    <class 'torch.cuda.FloatTensor'>
+ model.fcn16 features_at:48                         (1, 4096, 34, 60)    <class 'torch.cuda.FloatTensor'>

@Ben_Usman May I ask what you used to generate your GPU memory trace? Perhaps Python’s trace with added py3nvml calls?

1 Like

To comment on your question, do you use variable-sized batches as input? In that case, that might be caused by memory fragmentation (storages need to be re-allocated)…

1 Like

Thank you for the reply! Yes, it was sys.settrace with a custom trace function (the gpu_profile.py snippet shared later in this thread).

I think I have figured out the source of the leak, by the way. A significant portion of the code, like variable allocation and intermediate computations, was located within a single Python function scope, so I suspect those intermediate variables were not marked as free even though they were not used anywhere further. Putting in a lot of dels kind of helped, but just isolating each individual step of the computation into a separate function call, so that all intermediate variables are automatically freed at the end of the scope, seems to be a better solution. Does that sound reasonable in the context of PyTorch?
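
For example, the kind of restructuring I mean (a toy sketch, not my actual code):

import torch

def big_scope_step(model, x):
    # a and b (and the graph hanging off them) stay referenced until the
    # function returns, even though they are not needed once loss is computed
    a = model(x)
    b = (a ** 2).sum()
    loss = b + a.mean()
    loss.backward()

def compute_loss(model, x):
    a = model(x)
    b = (a ** 2).sum()
    return b + a.mean()

def isolated_step(model, x):
    # a and b are freed as soon as compute_loss() returns
    loss = compute_loss(model, x)
    loss.backward()

model = torch.nn.Linear(8, 8)
isolated_step(model, torch.randn(4, 8))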

I wonder if there could be some sort of “semantic garbage collection” that could detect that a variable is not used anywhere further and could therefore be freed.

Thanks!

16 Likes

Thanks for showing your code!

I’m happy you’ve resolved your memory issue; it’s a very useful observation you’ve made, and it’s good that it’s now here in public. Indeed, Python’s lack of block scoping can sometimes delay object destruction for unnecessarily long. Actually, I’ve used dels myself recently to release buffers at the end of each iteration in a loop processing variable-sized data. I’m afraid that “semantic garbage collection” is not technically possible because the language is dynamic.
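
Something along these lines (a toy sketch, not the code I actually used):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 1).to(device)

for step in range(100):
    n = int(torch.randint(1, 64, (1,)))      # variable-sized batch
    x = torch.randn(n, 16, device=device)
    loss = model(x).pow(2).mean()
    loss.backward()
    # drop the references (and the graph hanging off loss) before the next,
    # possibly differently-sized, batch is allocated
    del x, loss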

1 Like

Hello! I am trying to use this technique to debug, but the amount of GPU memory used seems to be an order of magnitude larger than the tensors being allocated.

After running a forward pass on my network I use the code above:

import operator as op
from functools import reduce

for obj in gc.get_objects():
    if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        print(reduce(op.mul, obj.size()) if len(obj.size()) > 0 else 0, type(obj), obj.size())

GPU Mem used is around 10GB after a couple of forward/backward passes.

(161280, <class 'torch.autograd.variable.Variable'>, (5, 14, 3, 24, 32))
(451584, <class 'torch.autograd.variable.Variable'>, (14, 14, 3, 24, 32))
(612864, <class 'torch.autograd.variable.Variable'>, (19, 14, 3, 24, 32))
(612864, <class 'torch.autograd.variable.Variable'>, (19, 14, 3, 24, 32))
(2, <class 'torch.autograd.variable.Variable'>, (2,))
(420, <class 'torch.autograd.variable.Variable'>, (30, 1, 14))
(1026000, <class 'torch.autograd.variable.Variable'>, (19, 15, 450, 8))
(202, <class 'torch.autograd.variable.Variable'>, (2, 101))
(0, <class 'torch.autograd.variable.Variable'>, ())
(3, <class 'torch.autograd.variable.Variable'>, (3,))
(70, <class 'torch.autograd.variable.Variable'>, (5, 14))
(45, <class 'torch.autograd.variable.Variable'>, (45,))
(13230, <class 'torch.autograd.variable.Variable'>, (90, 3, 7, 7))
(90, <class 'torch.autograd.variable.Variable'>, (90,))
(10, <class 'torch.autograd.variable.Variable'>, (10,))
(735, <class 'torch.autograd.variable.Variable'>, (15, 1, 7, 7))
(15, <class 'torch.autograd.variable.Variable'>, (15,))
(8, <class 'torch.autograd.variable.Variable'>, (1, 1, 1, 8, 1))
(808, <class 'torch.autograd.variable.Variable'>, (101, 8))
(101, <class 'torch.autograd.variable.Variable'>, (101,))
(3, <class 'torch.autograd.variable.Variable'>, (3,))
(70, <class 'torch.autograd.variable.Variable'>, (5, 14))
(45, <class 'torch.autograd.variable.Variable'>, (45,))
(13230, <class 'torch.autograd.variable.Variable'>, (90, 3, 7, 7))
(90, <class 'torch.autograd.variable.Variable'>, (90,))
(10, <class 'torch.autograd.variable.Variable'>, (10,))
(735, <class 'torch.autograd.variable.Variable'>, (15, 1, 7, 7))
(15, <class 'torch.autograd.variable.Variable'>, (15,))
(8, <class 'torch.autograd.variable.Variable'>, (1, 1, 1, 8, 1))
(808, <class 'torch.autograd.variable.Variable'>, (101, 8))
(101, <class 'torch.autograd.variable.Variable'>, (101,))

You can see that the biggest variable here should only total around 10 MB, and altogether they shouldn’t need much more space than that. Where is the hidden memory usage? My batch size IS variable, as referenced above, but I OOM after only a few batches. Am I confused about something, or is it using 10-100x more memory than it should?

Thanks for any insight!

@Ben_Usman Can you share how to use gpu_profile.py? Thank you!

Something to consider with variable-sized batches is that PyTorch allocates memory for the batches as needed and doesn’t free it inline, because the cost of calling garbage collection during the training loop is too high. With variable batch sizes this can lead to multiple instances of the same buffer for the batch in memory.

If you make sure that your variably sized batches start with the largest batch, then the initial memory allocated will be large enough to hold all batches and you won’t have crazy memory growth. The natural instinct of most programmers who order their batches is to do the opposite, which means that the same buffer gets allocated multiple times over the course of training and never gets freed. Even if the order is random, there’s still a lot of unnecessary allocation going on.

I ran into this with a language model with a random backprop-through-time window in its batching, and I was able to reduce the memory requirements by an order of magnitude by forcing the first batch to be the largest.
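
A hypothetical sketch of the idea (not my actual batching code):

import torch

# fake variable-length batches: (seq_len, batch, features)
batches = [torch.randn(int(torch.randint(5, 50, (1,))), 32, 128) for _ in range(100)]

# move the largest batch to the front so the very first allocation is already
# big enough to be reused for every later (smaller) batch
largest = max(range(len(batches)), key=lambda i: batches[i].shape[0])
batches[0], batches[largest] = batches[largest], batches[0]

for batch in batches:
    pass  # forward/backward as usual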

19 Likes

Cool! Your fix works. You saved my day (almost)!
I encountered an “out of memory” crash using Caffe to extract features with a pretrained ResNet. So I rewrote the code in PyTorch and still hit this error. After wrapping the extraction part in a function, it got much further: from 400+ images to 1200+ images. However, in the end it still ran out of memory.
I was watching the GPU RAM usage closely and saw it increase from 4 GB to 9 GB and then 11 GB before it crashed (the total GPU RAM is 12 GB).
I’m really confused why this happens; after all, I’m just using a standard model to do minimal processing.

Update:
I managed to reduce RAM use by wrapping the extraction code in a no_grad block:

with torch.no_grad():
    features = model(im).data

Now the code runs through all 15,000+ images without any error :smile:
Though I still don’t understand why keeping the gradients caused the memory problem, as I’m looping over each individual image, and in the next iteration the gradients should be dereferenced and released automatically.

Update 2:
Finally I solved the memory problem! I realized that in each iteration I was putting the input data into a new tensor, and PyTorch generates a new computation graph for it. That caused the used RAM to grow forever. Now I use a placeholder tensor and copy the data into it each iteration, and the RAM always stays at a low level :smile:
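
Roughly what I ended up with (simplified; the model and the image loop are stand-ins for my actual extraction code):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.AdaptiveAvgPool2d(1),
).to(device).eval()

# one fixed input buffer, allocated once and reused for every image
im_buffer = torch.empty(1, 3, 224, 224, device=device)

features = []
with torch.no_grad():
    for _ in range(100):                  # loop over images
        im = torch.rand(1, 3, 224, 224)   # stand-in for loading/preprocessing
        im_buffer.copy_(im)               # copy into the pre-allocated tensor
        features.append(model(im_buffer).cpu())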

13 Likes

I have a Trainer which wraps all my training code: model initialization, dataset, optimizers, etc. To do a hyperparameter search, I initialize my trainer within a for loop, e.g.

for _ in range(10):
    trainer.initialize()
    trainer.fit()
    recorders.append(trainer.recorder)

However, I can see that the VRAM usage slowly increases after each trainer.initialize().

To fix this I thought I would add:

# attempts to delete all currently alive Tensors and Variables
import torch
import gc
for obj in gc.get_objects():
    if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        del obj
torch.cuda.empty_cache()

at the end of every training loop, though it still doesn’t work and my VRAM continues to increase after every iteration regardless.

Are the tensors found in gc.get_objects() all the alive tensors? Or are there others hiding somewhere else?

Did you write the trainer class yourself or are you using some other API?
In any case, could you post or link to the code of initialize() and fit()?
What exactly is trainer.recorder? Could it be that it’s somehow holding a reference to the computation graph?

Hi, thanks for the prompt reply!
The trainer class is one that I wrote myself, but I believe I found the culprit: the recorder class was holding a reference to the computation graph. Though I thought that by deleting all tensors, wouldn’t that also delete the reference to the computation graph?
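
For reference, the problematic pattern looked roughly like this (hypothetical names, not my actual Trainer/Recorder code):

class Recorder:
    def __init__(self):
        self.losses = []

    def log(self, loss):
        # storing the loss tensor itself keeps its whole autograd graph alive:
        # self.losses.append(loss)
        # storing a plain Python number instead lets the graph be freed:
        self.losses.append(loss.item())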

1 Like

Good to hear you’ve found the bug! :slight_smile:

I’m not sure if deleting always works without any shortcomings. I’ve never used it, as it seems to be kind of a hack.

Just in case anyone is still facing this issue: I changed @Ben_Usman’s code snippet to debug only specific functions, and also to clear the GPU cache periodically, to analyze how much memory is actually used.

import os
import gc
import torch
import datetime

from py3nvml import py3nvml

PRINT_TENSOR_SIZES = True
# clears GPU cache frequently, showing only actual memory usage
EMPTY_CACHE = True
gpu_profile_fn = (f"{datetime.datetime.now():%d-%b-%y-%H:%M:%S}"
                  f"-gpu_mem_prof.txt")
if 'GPU_DEBUG' in os.environ:
    print('profiling gpu usage to ', gpu_profile_fn)

_last_tensor_sizes = set()


def _trace_lines(frame, event, arg):
    if event != 'line':
        return
    if EMPTY_CACHE:
        torch.cuda.empty_cache()
    co = frame.f_code
    func_name = co.co_name
    line_no = frame.f_lineno
    filename = co.co_filename
    py3nvml.nvmlInit()
    mem_used = _get_gpu_mem_used()
    where_str = f"{func_name} in {filename}:{line_no}"
    with open(gpu_profile_fn, 'a+') as f:
        f.write(f"{where_str} --> {mem_used:<7.1f}Mb\n")
        if PRINT_TENSOR_SIZES:
            _print_tensors(f, where_str)

    py3nvml.nvmlShutdown()


def trace_calls(frame, event, arg):
    if event != 'call':
        return
    co = frame.f_code
    func_name = co.co_name

    try:
        trace_into = str(os.environ['TRACE_INTO'])
    except:
        print(os.environ)
        exit()
    if func_name in trace_into.split(' '):
        return _trace_lines
    return


def _get_gpu_mem_used():
    handle = py3nvml.nvmlDeviceGetHandleByIndex(
        int(os.environ['GPU_DEBUG']))
    meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle)
    return meminfo.used/1024**2


def _print_tensors(f, where_str):
    global _last_tensor_sizes
    for tensor in _get_tensors():
        if not hasattr(tensor, 'dbg_alloc_where'):
            tensor.dbg_alloc_where = where_str
    new_tensor_sizes = {(x.type(), tuple(x.shape), x.dbg_alloc_where)
                        for x in _get_tensors()}
    for t, s, loc in new_tensor_sizes - _last_tensor_sizes:
        f.write(f'+ {loc:<50} {str(s):<20} {str(t):<10}\n')
    for t, s, loc in _last_tensor_sizes - new_tensor_sizes:
        f.write(f'- {loc:<50} {str(s):<20} {str(t):<10}\n')
    _last_tensor_sizes = new_tensor_sizes


def _get_tensors(gpu_only=True):
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                tensor = obj
            elif hasattr(obj, 'data') and torch.is_tensor(obj.data):
                tensor = obj.data
            else:
                continue

            if tensor.is_cuda:
                yield tensor
        except Exception as e:
            pass

To set up the profiler:

import os
import sys

from gpu_profile import trace_calls

os.environ['GPU_DEBUG'] = args.dev
os.environ['TRACE_INTO'] = 'train_epoch'
sys.settrace(trace_calls)

6 Likes

@smth I think that your method of finding all tensors via Python’s garbage collector does not account for all of them. I suppose a corner case is backpropagation: some tensors might be saved for the backward pass in a context and transformed (probably compressed in some way), hence they no longer appear as tensors. I wrote a method that also accounts for the saved_tensors in the context for the backward pass. Could you please check whether it extracts all the saved tensors correctly?

import gc
import torch


def get_tensors(only_cuda=False, omit_objs=[]):
    """

    :return: list of active PyTorch tensors
    >>> import torch
    >>> from torch import tensor
    >>> clean_gc_return = map((lambda obj: del_object(obj)), gc.get_objects())
    >>> device = "cuda" if torch.cuda.is_available() else "cpu"
    >>> device = torch.device(device)
    >>> only_cuda = True if torch.cuda.is_available() else False
    >>> t1 = tensor([1], device=device)
    >>> a3 = tensor([[1, 2], [3, 4]], device=device)
    >>> # print(get_all_tensor_names())
    >>> tensors = [tensor_obj for tensor_obj in get_tensors(only_cuda=only_cuda)]
    >>> # print(tensors)
    >>> # We doubled each t1, a3 tensors because of the tensors collection.
    >>> expected_tensor_length = 2
    >>> assert len(tensors) == expected_tensor_length, f"Expected length of tensors {expected_tensor_length}, but got {len(tensors)}, the tensors: {tensors}"
    >>> exp_size = (2,2)
    >>> act_size = tensors[1].size()
    >>> assert exp_size == act_size, f"Expected size {exp_size} but got: {act_size}"
    >>> del t1
    >>> del a3
    >>> clean_gc_return = map((lambda obj: del_object(obj)), tensors)
    """
    add_all_tensors = False if only_cuda is True else True
    # To avoid counting the same tensor twice, create a dictionary of tensors,
    # each one identified by its id (the in memory address).
    tensors = {}

    # omit_obj_ids = [id(obj) for obj in omit_objs]

    def add_tensor(obj):
        if torch.is_tensor(obj):
            tensor = obj
        elif hasattr(obj, 'data') and torch.is_tensor(obj.data):
            tensor = obj.data
        else:
            return

        if (only_cuda and tensor.is_cuda) or add_all_tensors:
            tensors[id(tensor)] = tensor

    for obj in gc.get_objects():
        try:
            # Add the obj if it is a tensor.
            add_tensor(obj)
            # Some tensors are "saved & hidden" for the backward pass.
            if hasattr(obj, 'saved_tensors') and (id(obj) not in omit_objs):
                for tensor_obj in obj.saved_tensors:
                    add_tensor(tensor_obj)
        except Exception as ex:
            pass
            # print("Exception: ", ex)
            # logger.debug(f"Exception: {str(ex)}")
    return tensors.values()  # return a list of detected tensors

@Adam_Dziedzic If I remember correctly, saved_tensors will only be triggered on obj for functions in Python land, or functions that are directly alive. For autograd functions that are not alive in Python anymore (but are kept alive because another Python object refers to them as part of a grad_fn chain), those won’t show up.

@smth Where does the rest of the memory live? I have an example where walking the gc objects as above gives me a number less than half of the value returned by torch.cuda.memory_allocated(). In my case, the gc-object approach gives about 1.1 GB while torch.cuda.memory_allocated() returns 2.8 GB.

Where is the rest hiding? This doesn’t seem like it could just be PyTorch bookkeeping overhead.
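
For reference, this is roughly how I’m measuring the gc-visible side (a simplified sketch; it doesn’t de-duplicate tensors that share the same storage):

import gc
import torch

visible_bytes = 0
for obj in gc.get_objects():
    try:
        if torch.is_tensor(obj) and obj.is_cuda:
            visible_bytes += obj.element_size() * obj.nelement()
    except Exception:
        pass

print(f"gc-visible CUDA tensors:     {visible_bytes / 1024 ** 2:.1f} MB")
print(f"torch.cuda.memory_allocated: {torch.cuda.memory_allocated() / 1024 ** 2:.1f} MB")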