Uncollected object references from torch.load() cause memory growth

Hi all,

I’m currently training a model on a rather large dataset (~1 billion examples in total). To store and load this dataset, I’m using the WebDataset library (https://github.com/webdataset/webdataset): preprocessed examples are stored as PyTorch tensors in tar files, per the WebDataset spec. After a long period of training, the machine performing the reading/loading runs out of memory. After some debugging, the cause appears to be some (at least to me) unexpected behavior in the PyTorch serialization code.
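For context, the on-disk layout is roughly the following (a minimal sketch using the standard tarfile module rather than the actual WebDataset writer; write_shard/read_shard and the naming scheme are illustrative):

import io
import tarfile

import torch

def write_shard(path, tensors):
    # One serialized tensor per file inside a tar shard, per the WebDataset layout
    with tarfile.open(path, 'w') as tar:
        for i, t in enumerate(tensors):
            buf = io.BytesIO()
            torch.save(t, buf)
            info = tarfile.TarInfo(name=f'sample_{i:09d}.pth')
            info.size = buf.tell()
            buf.seek(0)
            tar.addfile(info, buf)

def read_shard(path):
    # Every sample is its own file, so every sample costs one torch.load()
    with tarfile.open(path, 'r') as tar:
        for member in tar:
            data = tar.extractfile(member).read()
            yield torch.load(io.BytesIO(data))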

Since each tensor is stored as a separate file before being tarred together, each one must be loaded individually, meaning ~1 billion torch.load() calls take place per epoch. I traced memory usage with the tracemalloc utility and discovered that some part of the serialization process creates objects that are never garbage collected, even after the tensor read from disk goes out of scope (or, in the example below, is manually deleted).
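To narrow things down, I filtered the snapshots to allocations made inside torch/serialization.py and printed the full allocation tracebacks (a diagnostic sketch using only the standard tracemalloc API):

import tracemalloc

tracemalloc.start(30)
# ... run the torch.load() loop here ...

# Keep only allocations whose traceback passes through torch/serialization.py
snapshot = tracemalloc.take_snapshot().filter_traces([
    tracemalloc.Filter(True, '*torch/serialization.py'),
])
for stat in snapshot.statistics('traceback')[:3]:
    print(stat)
    for line in stat.traceback.format():
        print(line)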

I have included a minimal example below to show these objects. It creates a dummy tensor, then reads it back into memory repeatedly, querying memory usage at 10,000-step intervals. Given that the tensor goes out of scope immediately after it is read with torch.load(), I would not expect any objects from torch/serialization.py to still be alive. Instead, I see ~10,000 new objects each time memory is queried (every 10,000 steps; see the example output below). These objects are apparently never released and cause constant linear memory growth.

Is there something I am missing here? Is there a way to remove references to these objects so they can be collected? Even though they average only 28 B each, they cause memory use to grow linearly until the machine runs out of memory (usually after many millions of examples have been read from disk on a high-mem training machine).

Thank you in advance for any and all help you can offer!

Best,
Tyler

System info

  • Ubuntu 18.04.1
  • Python 3.8
  • PyTorch 1.9.0

Minimal reproducible example:

import torch
import tracemalloc

if __name__ == '__main__':
    # Create a dummy Tensor serialized to disk
    rand = torch.rand((20, 10))
    torch.save(rand, 'test.pth')

    # Start tracemalloc
    tracemalloc.start(30)
    old_snapshot = tracemalloc.take_snapshot()

    for i in range(100000):
        # Load the Tensor
        test_tensor = torch.load('test.pth')

        # Immediately delete reference to test_tensor
        del test_tensor

        if i % 10000 == 0 and i != 0:
            # Take snapshot
            snapshot = tracemalloc.take_snapshot()

            # Print changes in memory consumption
            print(f'################# STEP {i} #################')
            for stat in snapshot.compare_to(old_snapshot, 'lineno')[:2]:
                print(str(stat))
            print('############################################')

            # Save snapshot
            old_snapshot = snapshot

Example output:

################# STEP 10000 #################
.../python3.8/site-packages/torch/serialization.py:845: size=274 KiB (+274 KiB), count=10005 (+10005), average=28 B
.../python3.8/site-packages/torch/serialization.py:242: size=274 KiB (+274 KiB), count=10003 (+10003), average=28 B
################# STEP 20000 #################
.../python3.8/site-packages/torch/serialization.py:845: size=547 KiB (+273 KiB), count=20005 (+10000), average=28 B
.../python3.8/site-packages/torch/serialization.py:242: size=547 KiB (+273 KiB), count=20003 (+10000), average=28 B
################# STEP 30000 #################
.../python3.8/site-packages/torch/serialization.py:242: size=820 KiB (+273 KiB), count=30003 (+10000), average=28 B
.../python3.8/site-packages/torch/serialization.py:845: size=821 KiB (+273 KiB), count=30004 (+9999), average=28 B

The two lines being referenced from torch/serialization.py above are lines 845 and 242.

del x doesn’t guarantee garbage collection in Python; you need to try calling gc.collect().
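For illustration, del only removes a name binding; objects caught in a reference cycle are reclaimed by the cyclic collector, not by refcounting alone:

import gc

class Node:
    def __init__(self):
        self.ref = None

# Build a reference cycle: after del, the refcounts never reach zero,
# so only the cyclic garbage collector can reclaim these objects
a, b = Node(), Node()
a.ref, b.ref = b, a
del a, b
print(gc.collect())  # non-zero: the cycle was found and collected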

Hi @VitalyFedyunin, thanks for the suggestion! Unfortunately, the issue persists even after adding an explicit call to gc.collect() after each read. I’ve included an updated script and output below.

import torch
import tracemalloc
import gc

if __name__ == '__main__':
    # Create a dummy Tensor serialized to disk
    rand = torch.rand((20, 10))
    torch.save(rand, 'test.pth')

    # Start tracemalloc
    tracemalloc.start(30)
    old_snapshot = tracemalloc.take_snapshot()

    for i in range(30001):
        # Load the Tensor
        test_tensor = torch.load('test.pth')

        # Immediately delete reference to test_tensor
        del test_tensor
        gc.collect()

        if i % 10000 == 0 and i != 0:
            # Take snapshot
            snapshot = tracemalloc.take_snapshot()

            # Print changes in memory consumption
            print(f'################# STEP {i} #################')
            for stat in snapshot.compare_to(old_snapshot, 'lineno')[:2]:
                print(str(stat))
            print('############################################')

            # Save snapshot
            old_snapshot = snapshot

Output:

################# STEP 10000 #################
.../python3.8/site-packages/torch/serialization.py:845: size=274 KiB (+274 KiB), count=10003 (+10003), average=28 B
.../python3.8/site-packages/torch/serialization.py:242: size=274 KiB (+274 KiB), count=10003 (+10003), average=28 B
############################################
################# STEP 20000 #################
.../python3.8/site-packages/torch/serialization.py:845: size=547 KiB (+273 KiB), count=20003 (+10000), average=28 B
.../python3.8/site-packages/torch/serialization.py:242: size=547 KiB (+273 KiB), count=20003 (+10000), average=28 B
############################################
################# STEP 30000 #################
.../python3.8/site-packages/torch/serialization.py:845: size=820 KiB (+273 KiB), count=30003 (+10000), average=28 B
.../python3.8/site-packages/torch/serialization.py:242: size=820 KiB (+273 KiB), count=30003 (+10000), average=28 B
############################################

I’ve also opened an issue on the PyTorch GitHub here:
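In the meantime, the stopgap I’m considering is confining the loads to short-lived worker processes, so whatever serialization.py retains is released when each worker exits (an untested sketch; load_in_subprocess and the chunk size are illustrative, and it must be called under an if __name__ == '__main__' guard):

import multiprocessing as mp

import torch

def _load_batch(paths):
    # Runs inside the worker; any leaked objects die with the process
    return [torch.load(p) for p in paths]

def load_in_subprocess(paths, chunk_size=10000):
    # Recycle the worker after each chunk so leaked memory stays bounded
    ctx = mp.get_context('spawn')
    for start in range(0, len(paths), chunk_size):
        with ctx.Pool(1) as pool:
            yield from pool.apply(_load_batch, (paths[start:start + chunk_size],))

This adds pickling overhead for every tensor crossing the process boundary, so it only makes sense as a stopgap until the underlying leak is fixed.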