Memory usage keeps increasing in DataLoader (in very simple code)

Hi,

I create a DataLoader that loads features from local files by their file paths.
The DataLoader can be simplified as follows:

(P.S. All code was tested on PyTorch 1.0.0 and PyTorch 1.0.1. My machine has 256 GB of memory.)

import numpy as np
import torch
import torch.utils.data as data
import time

class MyDataSet(data.Dataset):

  def __init__(self):

    super(MyDataSet, self).__init__()

    # Assume that self.infoset holds the description information for the dataset.
    # Here it is a list of strings, about 8 GB in memory.
    # In my real project this infoset may be 40 GB in memory.
    self.infoset = [str(i).zfill(1024) for i in range(len(self))]


  def __getitem__(self, index):

    info = self.infoset[index]  # problem is here

    items = {}
    items['features'] = self.load_feature(info)

    return items

  def load_feature(self, info):
    '''
    Load a feature from disk (simplified here to return a dummy tensor)
    '''
    feature = torch.Tensor(np.ones([8, 4, 2], dtype=np.float32))

    return feature

  def __len__(self):

    return 8000000

dataset = MyDataSet()

dataloader = data.DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=16, pin_memory=True)

while True:

  for i, sample in enumerate(dataloader):

    print(i, len(dataloader))

    time.sleep(0.05)  # slow the loop down to watch the memory usage grow during one epoch

During each epoch, memory usage starts at about 13 GB and keeps increasing, finally reaching about 46 GB. Although it drops back to 13 GB at the beginning of the next epoch, this problem is serious because in my real project the infoset is about 40 GB (due to the large number of samples), which eventually leads to an out-of-memory (OOM) error.
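
To make the growth easier to see, memory can be logged from inside the loop. The sketch below is not from the original post: it assumes the psutil package is available and sums the resident set size (RSS) of the main process and the DataLoader worker processes.

import os
import time
import psutil

def total_rss_gb():
  '''Return the RSS of this process plus all of its children, in GB.'''
  proc = psutil.Process(os.getpid())
  rss = proc.memory_info().rss
  for child in proc.children(recursive=True):
    try:
      rss += child.memory_info().rss
    except psutil.NoSuchProcess:
      pass  # a worker may have exited between listing and querying it
  return rss / 1024 ** 3

for i, sample in enumerate(dataloader):
  print(i, len(dataloader), '%.1f GB' % total_rss_gb())
  time.sleep(0.05)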

I have found that the problem is caused by the first line of MyDataSet.__getitem__(). In the following code, infoset is only read but never used, yet the same problem still occurs:

class MyDataSet(data.Dataset):

  def __init__(self):

    super(MyDataSet, self).__init__()

    # Assume that self.infoset holds the description information for the dataset.
    # Here it is a list of strings, about 8 GB in memory.
    # In my real project this infoset may be 40 GB in memory.
    self.infoset = [str(i).zfill(1024) for i in range(len(self))]


  def __getitem__(self, index):

    info = self.infoset[index]  # problem is here

    items = {}
    # items['features'] = self.load_feature(info)

    return items

  def load_feature(self, info):
    '''
    Load a feature from disk (simplified here to return a dummy tensor)
    '''
    feature = torch.Tensor(np.ones([8, 4, 2], dtype=np.float32))

    return feature

  def __len__(self):

    return 8000000

dataset = MyDataSet()

dataloader = data.DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=16, pin_memory=True)

while True:

  for i, sample in enumerate(dataloader):

    print(i, len(dataloader))

    time.sleep(0.05)  # slow the loop down to watch the memory usage grow during one epoch

By contrast, when I comment out the first line of MyDataSet.__getitem__(), the memory usage stays stable:

class MyDataSet(data.Dataset):

  def __init__(self):

    super(MyDataSet, self).__init__()

    # Assume that self.infoset holds the description information for the dataset.
    # Here it is a list of strings, about 8 GB in memory.
    # In my real project this infoset may be 40 GB in memory.
    self.infoset = [str(i).zfill(1024) for i in range(len(self))]


  def __getitem__(self, index):

    # info = self.infoset[index]  # problem is here
    info = 'fake info'

    items = {}
    items['features'] = self.load_feature(info)

    return items

  def load_feature(self, info):
    '''
    Load a feature from disk (simplified here to return a dummy tensor)
    '''
    feature = torch.Tensor(np.ones([8, 4, 2], dtype=np.float32))

    return feature

  def __len__(self):

    return 8000000

dataset = MyDataSet()

dataloader = data.DataLoader(dataset, batch_size=1024, shuffle=True, num_workers=16, pin_memory=True)

while True:

  for i, sample in enumerate(dataloader):

    print(i, len(dataloader))

    time.sleep(0.05)  # slow the loop down to watch the memory usage grow during one epoch

Any ideas about the cause, or suggestions for fixing this problem?

Did you manage to find a solution? I'm hitting the same issue.
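
For readers hitting the same thing: this behaviour is commonly attributed to copy-on-write. The DataLoader workers are forked from the main process, and every self.infoset[index] access touches the reference count of a Python string object, which dirties the memory page holding it, so each worker gradually materialises its own copy of the shared list. A frequently suggested workaround, sketched below as an assumption rather than a fix verified in this thread, is to keep the metadata in a single NumPy byte array (here with fixed-width 1024-character entries), because indexing a NumPy array reads from one contiguous buffer and does not write to shared per-element object headers:

import numpy as np
import torch
import torch.utils.data as data

class MyDataSet(data.Dataset):

  def __init__(self):
    super(MyDataSet, self).__init__()
    # Same metadata as before, but stored in one contiguous byte buffer.
    # Reading an element creates a fresh numpy.bytes_ object instead of
    # touching the refcount of a shared Python string.
    self.infoset = np.array([str(i).zfill(1024).encode('ascii')
                             for i in range(len(self))], dtype='S1024')

  def __getitem__(self, index):
    info = self.infoset[index].decode('ascii')
    items = {}
    items['features'] = self.load_feature(info)
    return items

  def load_feature(self, info):
    '''
    Load a feature from disk (simplified here to return a dummy tensor)
    '''
    return torch.Tensor(np.ones([8, 4, 2], dtype=np.float32))

  def __len__(self):
    return 8000000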