Hi all,
I’m currently working on training a model on a rather large dataset (~1 billion examples in total). To store and load this dataset, I’m using the WebDataset library (https://github.com/webdataset/webdataset). I store preprocessed examples as PyTorch tensors in tar files, per the WebDataset spec. I’m running into a memory issue where, after a long period of training, the machine performing the reading/loading runs out of memory. After some debugging, it appears to be caused by behavior (unexpected, at least to me) in the PyTorch serialization code.
Since each tensor is stored as a separate file and then tarred together with the others, each must be loaded individually, meaning ~1 billion torch.load() operations take place per epoch. I traced memory usage with the tracemalloc utility and discovered that some part of this serialization process creates objects that are not garbage collected even after the tensor read from disk goes out of scope (or, in the example below, is manually deleted).
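For concreteness, the per-sample read pattern looks roughly like the sketch below. It uses the stdlib tarfile module directly and pickle in place of torch.load, purely so it is self-contained; the real pipeline goes through WebDataset and torch tensors.

    import io
    import pickle
    import tarfile

    # Build a tiny tar shard in memory with two pickled "tensors"
    # (plain lists stand in for torch tensors in this sketch).
    shard = io.BytesIO()
    with tarfile.open(fileobj=shard, mode="w") as tar:
        for name, payload in [("a.pth", [1.0, 2.0]), ("b.pth", [3.0, 4.0])]:
            data = pickle.dumps(payload)
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))

    # Reading side: one load call per tar member. This per-member load is
    # what turns into ~1 billion torch.load() calls per epoch at scale.
    shard.seek(0)
    loaded = []
    with tarfile.open(fileobj=shard, mode="r") as tar:
        for member in tar:
            f = tar.extractfile(member)
            loaded.append(pickle.loads(f.read()))

    print(loaded)  # [[1.0, 2.0], [3.0, 4.0]]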
I have included a minimal example below to show these objects. It creates a dummy tensor, then reads it into memory repeatedly, querying memory usage every 10,000 steps. Given that the tensor goes out of scope immediately after it is read with torch.load(), I would not expect any objects from torch/serialization.py to still be alive. Instead, I see ~10,000 new objects created each time memory is queried (i.e., every 10,000 steps; see the example output below). These objects also seem to never be released, causing constant linear memory growth.
Is there something I am missing here? Is there a way to drop the references to these objects so they can be collected? Even though they are only 28 B each on average, they still cause memory use to grow linearly until the machine is out of memory (usually after many millions of examples have been read from disk on a high-memory training machine).
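One thing worth noting: tracemalloc only reports blocks that are still allocated, so forcing an explicit gc.collect() before each snapshot distinguishes a true leak from garbage that merely has not been collected yet. Below is that variation on the snapshot loop, sketched with pickle standing in for torch.load so it runs standalone:

    import gc
    import pickle
    import tracemalloc

    # A dummy serialized payload (stand-in for the .pth file on disk).
    payload = pickle.dumps(list(range(100)))

    tracemalloc.start(30)
    old_snapshot = tracemalloc.take_snapshot()

    for i in range(1, 20001):
        obj = pickle.loads(payload)
        del obj
        if i % 10000 == 0:
            # Force a full collection before measuring: any growth that
            # survives this call is genuinely referenced somewhere, not
            # just awaiting collection.
            gc.collect()
            snapshot = tracemalloc.take_snapshot()
            top = snapshot.compare_to(old_snapshot, 'lineno')[:2]
            for stat in top:
                print(stat)
            old_snapshot = snapshot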
Thank you in advance for any and all help you can offer!
Best,
Tyler
System info
- Ubuntu 18.04.1
- Python 3.8
- PyTorch 1.9.0
Minimal reproducible example:
import torch
import tracemalloc

if __name__ == '__main__':
    # Create a dummy tensor and serialize it to disk
    rand = torch.rand((20, 10))
    torch.save(rand, 'test.pth')

    # Start tracemalloc, keeping 30 frames of traceback per allocation
    tracemalloc.start(30)
    old_snapshot = tracemalloc.take_snapshot()

    for i in range(100000):
        # Load the tensor
        test_tensor = torch.load('test.pth')
        # Immediately delete the reference to test_tensor
        del test_tensor

        if i % 10000 == 0 and i != 0:
            # Take a snapshot
            snapshot = tracemalloc.take_snapshot()
            # Print changes in memory consumption
            print(f'################# STEP {i} #################')
            for stat in snapshot.compare_to(old_snapshot, 'lineno')[:2]:
                print(str(stat))
            print('############################################')
            # Keep this snapshot as the new baseline
            old_snapshot = snapshot
Example output:
################# STEP 10000 #################
.../python3.8/site-packages/torch/serialization.py:845: size=274 KiB (+274 KiB), count=10005 (+10005), average=28 B
.../python3.8/site-packages/torch/serialization.py:242: size=274 KiB (+274 KiB), count=10003 (+10003), average=28 B
################# STEP 20000 #################
.../python3.8/site-packages/torch/serialization.py:845: size=547 KiB (+273 KiB), count=20005 (+10000), average=28 B
.../python3.8/site-packages/torch/serialization.py:242: size=547 KiB (+273 KiB), count=20003 (+10000), average=28 B
################# STEP 30000 #################
.../python3.8/site-packages/torch/serialization.py:242: size=820 KiB (+273 KiB), count=30003 (+10000), average=28 B
.../python3.8/site-packages/torch/serialization.py:845: size=821 KiB (+273 KiB), count=30004 (+9999), average=28 B
The two lines being referenced above are lines 242 and 845 of torch/serialization.py.