Looping over Tensor leaks memory?

Hello,

The following code generates 30 million float32 values, which should occupy about 30,000,000 * 4 bytes / 1024^2 ≈ 114 MB. My machine has 16 GB of RAM and runs out of memory instantly!

import torch

data = torch.rand(30000000)

for item in data:
    pass
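
Just to double-check, the tensor itself really is that small:

print(data.element_size() * data.nelement() / 1024**2)   # 30,000,000 * 4 bytes ≈ 114 MB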

I didn’t find an open issue for this, and I haven’t created one yet.

I’m on Arch Linux and work in a virtual environment using:

Python 3.7.12 (default, Dec 2 2021, 11:47:57)
[GCC 11.1.0] on linux

and

$ pip list
Package            Version   Editable project location
------------------ --------- ----------------------------------
astropy            4.3.1
beautifulsoup4     4.10.0
cached-property    1.5.2
certifi            2021.10.8
cffi               1.15.0
charset-normalizer 2.0.7
cryptography       35.0.0
cycler             0.10.0
Cython             0.29.24
decorator          5.1.0
gwdatafind         1.0.4
h5py               3.6.0
idna               3.3
importlib-metadata 4.8.1
Jinja2             3.0.2
kiwisolver         1.3.2
lalsuite           7.0
ligo-segments      1.4.0
lscsoft-glue       2.0.0
Mako               1.1.5
MarkupSafe         2.0.1
matplotlib         3.4.3
mpi4py             3.1.3
mpld3              0.5.5
numpy              1.21.2
Pillow             8.4.0
pip                21.3.1
PyCBC              0.0a8210  /home/pascal/venv_thesis/src/pycbc
pycparser          2.20
pyerfa             2.0.0
pyOpenSSL          21.0.0
pyparsing          2.4.7
python-dateutil    2.8.2
requests           2.26.0
scipy              1.7.1
setuptools         60.0.0
six                1.16.0
soupsieve          2.2.1
torch              1.9.1
tqdm               4.62.3
typing-extensions  3.10.0.2
urllib3            1.26.7
zipp               3.6.0

Hi Fancy!

I can reproduce this on a 32 GB Ubuntu machine by doubling the size
of data:

data = torch.rand(60000000)

Python 3.8.3 (default, May 19 2020, 18:47:26)
>>> import torch
>>> torch.__version__
'1.10.0'
>>> data = torch.rand(60000000)
>>> for item in data:
...     pass
...
Killed

top shows memory (%MEM) increasing up to 100% before the python
process is killed. This happens after a little more than half a minute.

Note that this does not happen with a numpy array:

import numpy as np
data = np.random.rand(60000000)

(In this case the loop takes a few seconds to complete.)
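
One relevant difference (just a small sketch to illustrate): iterating a numpy array yields lightweight numpy scalars, whereas iterating a torch tensor yields a separate 0-dim Tensor for every element:

import numpy as np
import torch

a = np.random.rand(5)
t = torch.rand(5)

print(type(next(iter(a))))   # <class 'numpy.float64'> -- plain scalar
print(type(next(iter(t))))   # <class 'torch.Tensor'>  -- 0-dim tensor

Each 0-dim Tensor carries far more per-object overhead than a numpy scalar, which could be where the memory goes if all of them are created eagerly.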

Best.

K. Frank

Hi Fancy!

Furthermore, it looks like the full time and memory are taken up by
creating the iterator (or whatever implements for item in data:)
rather than by actually executing the loop:

>>> import torch
>>> torch.__version__
'1.10.0'
>>> data = torch.rand(30000000)
>>> for item in data:
...     break
...
>>> data = torch.rand(60000000)
>>> for item in data:
...     break
...
Killed
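
A possible explanation (an assumption on my part, based on Tensor.__iter__ apparently being implemented via unbind in these versions): constructing the iterator would then split the tensor into one 0-dim tensor per element up front. A tiny sketch of what unbind does:

import torch

t = torch.rand(3)
pieces = t.unbind(0)                   # tuple of 0-dim tensors, all created at once
print(len(pieces), type(pieces[0]))    # 3 <class 'torch.Tensor'>

With 60 million elements that would mean 60 million separate Tensor objects created before the loop body ever runs, which would match the behavior above.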

Best.

K. Frank

I was also able to reproduce the issue in torch==1.11.0.dev20220108 and came to the same conclusion: my print statements inside the loop were never reached, while memory usage increased massively.
@pascalm would you mind creating a GitHub issue so that we could track and fix this bug, please?

Thanks for pointing that out. I reduced it to

import torch
data = torch.rand(60000000)
it = iter(data)   # creating the iterator alone is enough to exhaust memory
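
In the meantime, a workaround that seems to sidestep the tensor iterator entirely (a sketch, assuming plain Python floats are enough for the loop body) is to loop over data.tolist() or data.numpy() instead:

import torch

data = torch.rand(30000000)

# plain Python floats; builds a big list, but stays well within RAM
for item in data.tolist():
    pass

# iterating over data.numpy() also appears to avoid the blow-up:
# for item in data.numpy():
#     pass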

I’ll create a GitHub issue.