DataLoader Memory Leak?

I have the following situation: I’m trying to train a UNet learner using the fastai library. My data is stored as float16 tensors saved with torch.save and loaded via a custom load function. In fastai, you create a Learner object and then call learn.fit() to train your model. My memory usage grows linearly during training until I run out of memory. Interestingly, memory usage resets between epochs, so the problem arises within a single epoch of training. Under the hood, fastai iterates through a PyTorch dataloader and does its work on top of that.

My Learner object has a learn.data.train_dl.dl attribute, which is a torch.utils.data.dataloader.DataLoader. When I simply iterate through this dataloader with

for xb, yb in learn.data.train_dl.dl:
    pass

all my memory is consumed within a few minutes. I read on the forums that this could come from an issue with using multiple workers to load the files, but setting num_workers to 0 didn’t fix my memory issue.
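To see the growth directly, I watch the process’s resident memory while iterating the raw dataloader (a quick sketch using psutil; learn is the Learner from above):

import os
import psutil

proc = psutil.Process(os.getpid())

# Iterate the underlying PyTorch DataLoader and print resident memory
# every 100 batches; the RSS climbs steadily until RAM is exhausted.
for i, (xb, yb) in enumerate(learn.data.train_dl.dl):
    if i % 100 == 0:
        print(f'batch {i}: RSS = {proc.memory_info().rss / 1024**2:.1f} MB')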

I would really like to be able to train my model. Do you have any idea how to fix this bug, or a workaround (for example, replacing the for x, y in dl loop with something else) to get around this issue?

For more detail on how I’m loading the data: I’m using a custom fastai class, TensorImageList(ImageList), where I just override the open(self, fn) method with:

class TensorImageList(ImageList):  
    def open(self, fn):
        return torch.load(fn, map_location='cpu').type(torch.float)

I’ve also opened a post on the fastai forum discussing this issue: https://forums.fast.ai/t/learn-fit-one-cycle-makes-me-run-out-of-memory-on-cpu-while-i-train-on-gpu/45428

If you see the memory usage increase during the dummy DataLoader loop, the issue might be in the Dataset, in particular in how you are loading/storing the data.
Could you post the code for your Dataset? The TensorImageList class doesn’t look like it corresponds to a Dataset, as the __getitem__ and __len__ methods are missing.
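For reference, a map-style Dataset usually looks roughly like this (a generic sketch with hypothetical names, just to show the two required methods):

import torch
from torch.utils.data import Dataset

class TensorFileDataset(Dataset):
    # Hypothetical minimal dataset over tensors saved with torch.save;
    # `file_paths` is a list of paths to the saved tensors.
    def __init__(self, file_paths):
        self.file_paths = file_paths

    def __len__(self):
        # Number of samples the DataLoader can index.
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Load a single sample; the DataLoader collates these into batches.
        return torch.load(self.file_paths[idx], map_location='cpu').float()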

Basically, I’m using fastai’s data block API: https://docs.fast.ai/data_block.html

My databunch is created by calling TensorImageList.from_folder(), where TensorImageList is a subclass of ImageList.

I’ve created a minimal example on this repo : https://github.com/StatisticDean/fastai_memory_cpu/blob/master/test_leak.ipynb

To access the dataset from the databunch data, one needs to use data.train_dl.dl.dataset. The dataset is a LabelList, a subclass of torch.utils.data.dataset.Dataset created by fastai in data_block.py.

My databunch is created with:
data = TensorImageList.from_folder('./data_test/', extensions='.ti').split_by_rand_pct().label_from_folder().databunch(bs=8, num_workers=0)

The __getitem__ method of the LabelList is the following:

def __getitem__(self,idxs:Union[int,np.ndarray])->'LabelList':
    idxs = try_int(idxs)
    if isinstance(idxs, Integral):
        if self.item is None: x,y = self.x[idxs],self.y[idxs]
        else:                 x,y = self.item   ,0
        if self.tfms or self.tfmargs:
            x = x.apply_tfms(self.tfms, **self.tfmargs)
        if hasattr(self, 'tfms_y') and self.tfm_y and self.item is None:
            y = y.apply_tfms(self.tfms_y, **{**self.tfmargs_y, 'do_resolve':False})
        if y is None: y=0
        return x,y
    else: return self.new(self.x[idxs], self.y[idxs])

So basically, it calls the __getitem__ method of self.x, which in my case is a TensorImageList. That __getitem__ method is the following:

def __getitem__(self,idxs:int)->Any:
    idxs = try_int(idxs)
    if isinstance(idxs, Integral): return self.get(idxs)
    else: return self.new(self.items[idxs], inner_df=index_row(self.inner_df, idxs))

It calls self.get, which is the following:

def get(self, i):
    fn = super().get(i)
    res = self.open(fn)
    self.sizes[i] = res.size
    return res

And at this point, open is exactly the method I overrode.

After closer inspection, I noticed that the default open method returns an Image, which is a fastai class, instead of a plain tensor. So I replaced my

class TensorImageList(ImageList):
    def open(self, fn):
        return torch.load(fn, map_location='cpu').type(torch.float)

with

class TensorImageList(ImageList):
    def open(self, fn):
        return Image(torch.load(fn, map_location='cpu').type(torch.float))

And, like magic, memory is stable (at least when just iterating through the dataloader).

So if you just load the data, you’ll see an increase in memory, while wrapping the tensor into an Image class solves this issue?
Could you post the class definition of Image?

Honestly, I don’t understand where the issue is coming from either; this was just a quick fix. The error seems to be coming from the fastai side, though. According to sgugger on the fastai forum:

"Ah! I think this might be due to our `data_collate` default function, which collected the `data` inside your tensor instead of just grabbing your tensor. Why that didn’t release memory is beyond me, but I think if you pass the regular PyTorch collate function (which is `torch.utils.data.dataloader.default_collate`) to the `DataBunch` call, you won’t have a memory leak."
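If I understand that suggestion correctly, it amounts to something like the following (untested sketch; as far as I can tell, databunch() forwards collate_fn to the underlying DataLoader in place of fastai’s data_collate):

from torch.utils.data.dataloader import default_collate

# Same pipeline as before, but with PyTorch's default collate function
# passed through to the DataLoaders instead of fastai's data_collate.
data = (TensorImageList.from_folder('./data_test/', extensions='.ti')
        .split_by_rand_pct()
        .label_from_folder()
        .databunch(bs=8, num_workers=0, collate_fn=default_collate))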

The Image source code is the following:

class Image(ItemBase):
    "Support applying transforms to image data in `px`."
    def __init__(self, px:Tensor):
        self._px = px
        self._logit_px=None
        self._flow=None
        self._affine_mat=None
        self.sample_kwargs = {}

    def set_sample(self, **kwargs)->'ImageBase':
        "Set parameters that control how we `grid_sample` the image after transforms are applied."
        self.sample_kwargs = kwargs
        return self

    def clone(self):
        "Mimic the behavior of torch.clone for `Image` objects."
        return self.__class__(self.px.clone())

    @property
    def shape(self)->Tuple[int,int,int]: return self._px.shape
    @property
    def size(self)->Tuple[int,int]: return self.shape[-2:]
    @property
    def device(self)->torch.device: return self._px.device

    def __repr__(self): return f'{self.__class__.__name__} {tuple(self.shape)}'
    def _repr_png_(self): return self._repr_image_format('png')
    def _repr_jpeg_(self): return self._repr_image_format('jpeg')

    def _repr_image_format(self, format_str):
        with BytesIO() as str_buffer:
            plt.imsave(str_buffer, image2np(self.px), format=format_str)
            return str_buffer.getvalue()

    def apply_tfms(self, tfms:TfmList, do_resolve:bool=True, xtra:Optional[Dict[Callable,dict]]=None,
                   size:Optional[Union[int,TensorImageSize]]=None, resize_method:ResizeMethod=None,
                   mult:int=None, padding_mode:str='reflection', mode:str='bilinear', remove_out:bool=True)->TensorImage:
        "Apply all `tfms` to the `Image`, if `do_resolve` picks value for random args."
        if not (tfms or xtra or size): return self
        tfms = listify(tfms)
        xtra = ifnone(xtra, {})
        default_rsz = ResizeMethod.SQUISH if (size is not None and is_listy(size)) else ResizeMethod.CROP
        resize_method = ifnone(resize_method, default_rsz)
        if resize_method <= 2 and size is not None: tfms = self._maybe_add_crop_pad(tfms)
        tfms = sorted(tfms, key=lambda o: o.tfm.order)
        if do_resolve: _resolve_tfms(tfms)
        x = self.clone()
        x.set_sample(padding_mode=padding_mode, mode=mode, remove_out=remove_out)
        if size is not None:
            crop_target = _get_crop_target(size, mult=mult)
            if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):
                target = _get_resize_target(x, crop_target, do_crop=(resize_method==ResizeMethod.CROP))
                x.resize(target)
            elif resize_method==ResizeMethod.SQUISH: x.resize((x.shape[0],) + crop_target)
        else: size = x.size
        size_tfms = [o for o in tfms if isinstance(o.tfm,TfmCrop)]
        for tfm in tfms:
            if tfm.tfm in xtra: x = tfm(x, **xtra[tfm.tfm])
            elif tfm in size_tfms:
                if resize_method in (ResizeMethod.CROP,ResizeMethod.PAD):
                    x = tfm(x, size=_get_crop_target(size,mult=mult), padding_mode=padding_mode)
            else: x = tfm(x)
        return x.refresh()

    def refresh(self)->None:
        "Apply any logit, flow, or affine transfers that have been sent to the `Image`."
        if self._logit_px is not None:
            self._px = self._logit_px.sigmoid_()
            self._logit_px = None
        if self._affine_mat is not None or self._flow is not None:
            self._px = _grid_sample(self._px, self.flow, **self.sample_kwargs)
            self.sample_kwargs = {}
            self._flow = None
        return self

    def save(self, fn:PathOrStr):
        "Save the image to `fn`."
        x = image2np(self.data*255).astype(np.uint8)
        PIL.Image.fromarray(x).save(fn)

    @property
    def px(self)->TensorImage:
        "Get the tensor pixel buffer."
        self.refresh()
        return self._px
    @px.setter
    def px(self,v:TensorImage)->None:
        "Set the pixel buffer to `v`."
        self._px=v

    @property
    def flow(self)->FlowField:
        "Access the flow-field grid after applying queued affine transforms."
        if self._flow is None:
            self._flow = _affine_grid(self.shape)
        if self._affine_mat is not None:
            self._flow = _affine_mult(self._flow,self._affine_mat)
            self._affine_mat = None
        return self._flow

    @flow.setter
    def flow(self,v:FlowField): self._flow=v

    def lighting(self, func:LightingFunc, *args:Any, **kwargs:Any):
        "Equivalent to `image = sigmoid(func(logit(image)))`."
        self.logit_px = func(self.logit_px, *args, **kwargs)
        return self

    def pixel(self, func:PixelFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.px = func(image.px)`."
        self.px = func(self.px, *args, **kwargs)
        return self

    def coord(self, func:CoordFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.flow = func(image.flow, image.size)`."
        self.flow = func(self.flow, *args, **kwargs)
        return self

    def affine(self, func:AffineFunc, *args, **kwargs)->'Image':
        "Equivalent to `image.affine_mat = image.affine_mat @ func()`."
        m = tensor(func(*args, **kwargs)).to(self.device)
        self.affine_mat = self.affine_mat @ m
        return self

    def resize(self, size:Union[int,TensorImageSize])->'Image':
        "Resize the image to `size`, size can be a single int."
        assert self._flow is None
        if isinstance(size, int): size=(self.shape[0], size, size)
        if tuple(size)==tuple(self.shape): return self
        self.flow = _affine_grid(size)
        return self

    @property
    def affine_mat(self)->AffineMatrix:
        "Get the affine matrix that will be applied by `refresh`."
        if self._affine_mat is None:
            self._affine_mat = torch.eye(3).to(self.device)
        return self._affine_mat
    @affine_mat.setter
    def affine_mat(self,v)->None: self._affine_mat=v

    @property
    def logit_px(self)->LogitTensorImage:
        "Get logit(image.px)."
        if self._logit_px is None: self._logit_px = logit_(self.px)
        return self._logit_px
    @logit_px.setter
    def logit_px(self,v:LogitTensorImage)->None: self._logit_px=v

    @property
    def data(self)->TensorImage:
        "Return this images pixels as a tensor."
        return self.px

    def show(self, ax:plt.Axes=None, figsize:tuple=(3,3), title:Optional[str]=None, hide_axis:bool=True,
              cmap:str=None, y:Any=None, **kwargs):
        "Show image on `ax` with `title`, using `cmap` if single-channel, overlaid with optional `y`"
        cmap = ifnone(cmap, defaults.cmap)
        ax = show_image(self, ax=ax, hide_axis=hide_axis, cmap=cmap, figsize=figsize)
        if y is not None: y.show(ax=ax, **kwargs)
        if title is not None: ax.set_title(title)

The code can be found in fastai.vision.image.


I am experiencing the exact same memory-growth problem, although I am not using the fastai library. If I read the image into a class attribute and pass it to class methods, e.g. _pad_img and _extract_patches, the memory grows linearly with the iterations.

This is how I make my Dataset, which is then used by the PyTorch DataLoader:

import cv2
import numpy as np
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, patch_size, subdivisions, image_file):
        self.patch_size = patch_size
        self.subdivisions = subdivisions
        # Read the whole image once, pad it, and precompute patch coordinates.
        self.image = self._pad_img(cv2.cvtColor(cv2.imread(image_file), cv2.COLOR_BGR2RGB))
        self.coords = self._extract_patches(self.image)

    def _pad_img(self, img):
        aug = int(round(self.patch_size * (1 - 1.0 / self.subdivisions)))
        padding = ((aug, aug), (aug, aug), (0, 0))
        img_padded = np.pad(img, pad_width=padding, mode='reflect')
        return img_padded

    def _extract_patches(self, img):
        step = int(self.patch_size / self.subdivisions)
        row_range = range(0, img.shape[0] - self.patch_size + 1, step)
        col_range = range(0, img.shape[1] - self.patch_size + 1, step)
        coords = []
        for row in row_range:
            for col in col_range:
                left = col
                upper = row
                right = col + self.patch_size
                lower = row + self.patch_size
                coords += [(left, upper, right, lower)]
        return coords

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        box = self.coords[idx]
        image = self.image[box[1]:box[3], box[0]:box[2]]
        return image

I have tried reading the image inside _extract_patches instead. There is no memory problem that way, but the code becomes extremely slow.

Any ideas on how this can be fixed are appreciated.
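One idea I still want to try, based on reports that plain Python lists inside a Dataset can trigger copy-on-write memory growth in forked workers, is storing the coordinates as a NumPy array instead of a list of tuples (an untested sketch; __len__ keeps working because len() of the array returns its number of rows):

import numpy as np

# Hypothetical change inside MyDataset.__init__: one int64 array instead
# of a Python list of (left, upper, right, lower) tuples.
self.coords = np.array(self._extract_patches(self.image), dtype=np.int64)

# __getitem__ would then unpack a row of the array:
def __getitem__(self, idx):
    left, upper, right, lower = self.coords[idx]
    return self.image[upper:lower, left:right]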


I am also experiencing this issue, whether or not I set my num_workers to 0.

I have a 1TB dataset and here is my code:

class DicomDatasetRetriever(torch.utils.data.Dataset):
    def __init__(self, df, transforms=[], mix_ratio=1, mode='val'):
        self.df_main = df.copy()
        self.mode = mode
        self.mix_ratio = mix_ratio
        if self.mode == 'val':
            self.df = self.df_main
        else:
            self.update_train_df()
        self.lut = df[['SOPInstanceUID', 'image_path']].set_index('SOPInstanceUID')
        
        if not(len(transforms)):
            self.transforms = None
        else:
            self.transforms = A.Compose(transforms)
        
        self.default_transforms = A.Compose([
            A.Normalize(0.449, 0.226),
            ToTensorV2(),
        ])

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        study = row.StudyInstanceUID
        img_id = row.SOPInstanceUID
        
        img = self.load_image(img_id)
        pe_ratio = row.r_pe_present_on_image
        target = row.pe_present_on_image

        # transforms
        if self.transforms is not None and self.mode != 'val' and (row[3] == 1 or random.random() < 0.1):
            img = self.transforms(image=img)['image']
        
        # default transformation
        img = self.default_transforms(image=img)['image']

        return {
            'img': img,
            'img_id': img_id,
            'study_id': study,
            'pe_ratio': torch.tensor([pe_ratio]).float(),
            'target': torch.tensor([target]).float()
        }

    def load_image(self, img_id):
        file_path = self.lut.loc[img_id, 'image_path']
        # img = cv2.imread(file_path)
        with open(file_path, 'rb') as f:
            img = JPEG.decode(f.read())
        if img is None:
            print(f"Warning while trying to load image. No file at {file_path}")
            img = np.zeros(shape=SHAPE)
        img = img.astype(np.float32)
        img /= 255
        return img
    
    def update_train_df(self):
        df0 = self.df_main[self.df_main.pe_present_on_image==0]
        df1 = self.df_main[self.df_main.pe_present_on_image==1]
        df0 = df0.sample(frac=1)
        upto = min(int(self.mix_ratio * len(df1)), len(df0))
        self.df = pd.concat([df0.iloc[:upto],df1], axis=0)
        self.df = self.df.sample(frac=1)

class DataLoaders:
    def __init__(self, train_dataset, batch_size, num_workers):
        self.train_dataset = train_dataset
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=self.batch_size, shuffle=True,
            num_workers=self.num_workers, pin_memory=True)

There is a new tool called cstl (GitHub - fuzihaofzh/cstl: The C++ Standard Template Library (STL) for Python) that wraps C++ STL containers to address this issue. It supports multiple types, including nested maps, lists, and sets, which NumPy and PyTorch do not. The apparent "leak" with plain Python lists typically comes from CPython’s reference-count updates forcing copy-on-write in forked worker processes, so every worker gradually gets its own copy of the list; C++- or NumPy-backed containers avoid those per-object refcounts.
Here is a simple example showing how it solves the problem:

from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import cstl
from tqdm.auto import tqdm


class DataIter(Dataset):
    def __init__(self):
        cnt = 24000000
        self.cnt = cnt
        # self.data = np.array([x for x in range(cnt)])            # good: stable memory
        # self.data = [x for x in range(cnt)]                      # leaky: plain Python list
        # self.data = cstl.MapIntInt({i: i for i in range(cnt)})   # good: stable memory
        self.data = cstl.VecInt(range(cnt))                        # good: stable memory

    def __len__(self):
        return self.cnt

    def __getitem__(self, idx):
        data = self.data[idx]
        data = np.array([int(data)], dtype=np.int64)
        return torch.tensor(data)

train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=False,
                          num_workers=18)

for i, item in tqdm(enumerate(train_loader)):
    torch.cuda.empty_cache()
    if i % 1000 == 0:
        print(i)