DataLoader resets dataset state

I’ve implemented a custom dataset which generates and then caches the data for reuse.
If I use the DataLoader with num_workers=0, the first epoch is slow, as the data is generated during this time, but afterwards the cache is used and training proceeds quickly.
With a higher number of workers, the first epoch runs faster, but in every epoch after that the dataset’s cache is empty, so overall there is not much gain.

My understanding is that each worker keeps getting initialized with a fresh copy of the dataset each time, hence the problem (the exact mechanism is not clear to me, as I can see the dataset is only constructed once). Now I don’t expect the cache (or any data) to be shared between processes, but I would expect each worker to keep its own cache throughout training. Instead it seems like new workers are created at each epoch, and so the desired state is lost. Is there a way to prevent that from happening?

As far as I know, the Dataset will be copied into each worker process when you are using multiple workers.
This also means that you cannot modify the Dataset in place anymore, since each worker only modifies its own copy.
Here is a small example:

import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.ones(10)
        
    def __getitem__(self, index):
        self.data[-(index+1)] = 2.
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False
)
for data in loader:
    print(data)
for data in loader:
    print(data)


dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=1,
    shuffle=False
)
for data in loader:
    print(data)
for data in loader:
    print(data)

While the data is already modified in the second loop for num_workers=0, it stays the same for num_workers=1.

I’m not sure how you are implementing the caching, and there could still be a valid approach.
Could you share your implementation of caching?

Thanks for the quick reply.

Your code illustrates the problem.

The caching is simple: I’m just filling a list (an attribute of my dataset class) with the data items as they are computed in __getitem__ (let me know if this is not clear enough).
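For reference, a minimal sketch of what I mean (the names and the dummy compute step are placeholders):

import torch
from torch.utils.data import Dataset

class CachingDataset(Dataset):
    def __init__(self, n):
        self.n = n
        self.cache = [None] * n  # per-instance cache, filled lazily

    def __getitem__(self, index):
        if self.cache[index] is None:
            # expensive generation happens only the first time an index is requested
            self.cache[index] = torch.randn(3, 24, 24)
        return self.cache[index]

    def __len__(self):
        return self.n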

At the moment I think I’ll just work around it by computing all the items in __init__, though my initial thought was that I shouldn’t have to. It would be nice if new workers were not created for the second data iteration (i.e. the last loop above), or if each new worker got an updated copy of the dataset when new ones have to be created.
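A possible alternative, assuming a PyTorch version that supports it: the DataLoader’s persistent_workers flag keeps the worker processes, and therefore their dataset copies, alive across epochs. With shuffling each worker would still only cache the samples it happens to load itself, but roughly:

from torch.utils.data import DataLoader

# dataset: the caching Dataset instance from above
loader = DataLoader(
    dataset,
    num_workers=4,
    shuffle=True,
    persistent_workers=True,  # keep workers (and their per-worker caches) alive across epochs
)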

Yeah, I understand the issue and have stumbled over it myself a few times.

I think one possible approach would be to use shared memory in Python, e.g. with multiprocessing.Array.
You could initialize an array of the known size of your complete Dataset, fill it in the first epoch using all workers, and finally switch a flag indicating that the cache/shared memory should be used in all following epochs.
I’ve created a small dummy example showing this behavior.
Currently the shared memory is filled with torch.randn, so in this line of code you can add your heavy loading function.

import torch
from torch.utils.data import Dataset, DataLoader

import ctypes
import multiprocessing as mp

import numpy as np


class MyDataset(Dataset):
    def __init__(self):
        shared_array_base = mp.Array(ctypes.c_float, nb_samples*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        shared_array = shared_array.reshape(nb_samples, c, h, w)
        self.shared_array = torch.from_numpy(shared_array)
        self.use_cache = False
        
    def set_use_cache(self, use_cache):
        self.use_cache = use_cache
    
    def __getitem__(self, index):
        if not self.use_cache:
            print('Filling cache for index {}'.format(index))
            # Add your loading logic here
            self.shared_array[index] = torch.randn(c, h, w)
        x = self.shared_array[index]
        return x
    
    def __len__(self):
        return nb_samples


nb_samples, c, h, w = 10, 3, 24, 24

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=False
)

for epoch in range(2):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, data.shape {}'.format(epoch, idx, data.shape))
        
    if epoch==0:
        loader.dataset.set_use_cache(True)

Let me know if this would work for you.

Just tried it and it works.
Thanks a lot!

@ptrblck I am trying to utilize this caching method, but my pictures are now coming up blank, so I am obviously munging something somewhere. I am trying to cache images, so I am keeping things as NumPy arrays, since I need to use torchvision transforms, which means the data can’t be a tensor before that pre-processing. My understanding is that, because we are pre-allocating the entire array, we can actually use all the workers we would like from the beginning.

The default use_cache is False, and I am just trying to get things to work in that state, where it places the image into the cache and then reads it back from the cache. I realize that in the normal flow I would leave use_cache as False for the first epoch to load the cache, and then set it to True afterwards since the cache should be built.

class MDataset(Dataset):
    def __init__(self, df: pd.DataFrame, imfolder: str, train: bool = True, transforms = None, meta_features = None, use_cache = False):

        self.df = df
        self.imfolder = imfolder
        self.transforms = transforms
        self.train = train
        self.meta_features = meta_features
        
        c=3
        h=param['image_size'][0]
        w=param['image_size'][0]
        
        shared_array_base = mp.Array(ctypes.c_float, len(self.df)*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        self.shared_array = shared_array.reshape(len(self.df), h, w, c)
        self.use_cache = False
        
    def __getitem__(self, index):
        if not self.use_cache:
            im_path = os.path.join(self.imfolder, self.df.iloc[index]['image_name'] + '.jpg')
            x = cv2.imread(im_path)
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
            self.shared_array[index] = x
        x = self.shared_array[index]
        
        meta = np.array(self.df.iloc[index][self.meta_features].values, dtype=np.float32)

        if self.transforms:
            x = self.transforms(x)
            
        if self.train:
            y = self.df.iloc[index]['target'].astype("float32")
            return (x, meta), y
        else:
            return (x, meta)

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache
    
    def __len__(self):
        return len(self.df)

Here are some other parts of my code so you can see what I am doing:

pytorch_dataset = MDataset(df=test_df, 
                            imfolder= f"{param['image_dir']}test", 
                            train=False,
                            transforms=transform,
                            meta_features=meta_features)
pytorch_dataloader = DataLoader(dataset=pytorch_dataset, batch_size=12, shuffle=True,  pin_memory=param['pin_memory'], num_workers=param['num_workers'])

images,meta = next(iter(pytorch_dataloader))

show_transform(torchvision.utils.make_grid(images, nrow=6), title="Random Images")

If I comment out x = self.shared_array[index] then everything works fine, as it’s using the x from the file read from disk. But if I try to grab it from the cache, I don’t get any errors; I just get blank boxes where my image should be.

Is the code generally working for random inputs?
I just tested my code snippet and it still seems to work (at least I don’t get zero tensors).

Comparing the code, it seems you’ve just added the numpy array creation from the pd.DataFrame?

Well yes, I just cut out moving from NumPy to a tensor. It’s like whatever format is stored in that array is no longer a good image. In the other example you gave, you ran stack() after the data was loaded into an array, but I am not sure that is necessary with this new format.

Here is what it looks like if I comment out moving it to the shared_array:

tensor([[[ 0.5843,  0.5608,  0.5373,  ...,  0.5137,  0.5137,  0.5216],
         [ 0.5451,  0.5137,  0.5137,  ...,  0.5294,  0.5137,  0.5294],
         [ 0.5608,  0.5373,  0.5294,  ...,  0.5137,  0.5137,  0.5137],
         ...,
         [ 0.5294,  0.5451,  0.5529,  ...,  0.5686,  0.5765,  0.5922],
         [ 0.5216,  0.5529,  0.5608,  ...,  0.5765,  0.5843,  0.5922],
         [ 0.5451,  0.5843,  0.5922,  ...,  0.5765,  0.5922,  0.5922]],

        [[ 0.1451,  0.1216,  0.1059,  ...,  0.1294,  0.1294,  0.1373],
         [ 0.1059,  0.0824,  0.0745,  ...,  0.1451,  0.1451,  0.1451],
         [ 0.1216,  0.0980,  0.0824,  ...,  0.1373,  0.1373,  0.1373],
         ...,
         [-0.0039,  0.0118,  0.0196,  ...,  0.0824,  0.0902,  0.0902],
         [ 0.0118,  0.0431,  0.0510,  ...,  0.0902,  0.0980,  0.1059],
         [ 0.0353,  0.0745,  0.0745,  ...,  0.0902,  0.1059,  0.1059]],

        [[-0.0980, -0.1137, -0.1216,  ..., -0.1216, -0.1216, -0.1137],
         [-0.1294, -0.1608, -0.1529,  ..., -0.1059, -0.1059, -0.1059],
         [-0.1137, -0.1373, -0.1373,  ..., -0.1216, -0.1216, -0.1137],
         ...,
         [-0.2000, -0.1922, -0.1765,  ..., -0.1608, -0.1529, -0.1451],
         [-0.2000, -0.1686, -0.1373,  ..., -0.1529, -0.1451, -0.1373],
         [-0.1765, -0.1373, -0.1216,  ..., -0.1529, -0.1373, -0.1373]]])
torch.min(images[0]), torch.mean(images[0]), torch.max(images[0])
(tensor(-1.), tensor(0.2850), tensor(0.9294))

But then if I uncomment x = self.shared_array[index] I get this instead:

tensor([[[416.1446, 419.1934, 423.8561,  ..., 390.7272, 393.7298, 395.5690],
         [426.6692, 426.7228, 426.8047,  ..., 393.6453, 395.1738, 395.2061],
         [424.8294, 421.7952, 419.5968,  ..., 393.0822, 395.8900, 393.1715],
         ...,
         [392.0519, 397.7280, 398.4988,  ..., 398.1694, 396.6755, 398.2361],
         [400.2377, 396.8661, 396.8359,  ..., 399.0093, 398.7518, 400.2377],
         [404.3428, 404.0421, 399.0929,  ..., 398.3518, 402.8570, 404.3428]],

        [[327.8235, 329.5239, 333.7325,  ..., 285.4269, 288.2922, 292.8281],
         [340.1188, 340.1487, 340.2227,  ..., 285.2634, 288.8439, 288.9236],
         [338.3108, 335.2766, 333.0782,  ..., 282.1608, 286.9265, 286.8257],
         ...,
         [309.1382, 312.1176, 311.9801,  ..., 313.9214, 315.5509, 317.0721],
         [317.3240, 311.2557, 310.3172,  ..., 314.7612, 317.1865, 317.3240],
         [321.4291, 318.4316, 312.5742,  ..., 314.1038, 321.2916, 321.4291]],

        [[255.7246, 261.4700, 267.0410,  ..., 201.9358, 208.5293, 211.7169],
         [271.5615, 271.6625, 271.7604,  ..., 202.7995, 209.5271, 209.5832],
         [269.8168, 266.7826, 264.5842,  ..., 198.9806, 206.3089, 207.5169],
         ...,
         [244.2492, 247.2286, 247.0911,  ..., 242.5794, 245.2944, 246.9341],
         [252.4350, 246.3667, 245.4282,  ..., 243.4193, 248.2524, 252.4350],
         [256.5401, 253.5426, 247.6852,  ..., 242.7618, 252.3576, 256.5401]]])
torch.min(images[0]), torch.mean(images[0]), torch.max(images[0])
(tensor(56.4313), tensor(381.2817), tensor(484.6729))

So it looks like the image I grabbed from disk comes out of the Dataset normalized to [-1, 1], and yet when I put it into the cache, somehow it’s just raw. Both come out as tensors, which means they are going through torchvision.transforms, but it would appear that the cache read isn’t getting normalized. I have just basic transforms in this test:

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

Trying to figure out what’s going on, as all I do is that one step, and it shouldn’t be changing the image to that degree.

I guess ToTensor is not normalizing one of the data formats, while the other is normalized to [0, 1].
Could you load a single sample into the same data structure and dtype and use ToTensor() on it?
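For example, something like this (a quick sketch; to my knowledge ToTensor only scales uint8 inputs by 1/255 and passes float arrays through unscaled):

import numpy as np
from torchvision import transforms

to_tensor = transforms.ToTensor()

img_uint8 = np.full((4, 4, 3), 255, dtype=np.uint8)
img_float32 = img_uint8.astype(np.float32)

print(to_tensor(img_uint8).max())    # tensor(1.) -> scaled to [0, 1]
print(to_tensor(img_float32).max())  # tensor(255.) -> returned without scaling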

It’s the torchvision transforms that are not working when the data is in the cache. I know torchvision is made to work with PIL images, but I would think the data coming out of the cache is byte for byte the same as the image loaded with cv2.imread(), shouldn’t it be?

@ptrblck I got it fixed. The issue was that self.shared_array was float32. I originally tried to fix this by instantiating it as uint8 like so:

       shared_array_base = mp.Array(ctypes.c_uint, len(self.df)*c*h*w)
       shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())

but that was not enough. np.ctypeslib.as_array() does not have a dtype parameter, and I had assumed NumPy would convert the ctype to the dtype I wanted automatically, but it does not, so the above just ended up creating it as uint32. So I had to coerce it to uint8. Once I did that, it now displays images. I have not fully tested the cache functionality yet, but I plan to do that now that I can put a good image in and get a good image out. My changed code is below:

        shared_array_base = mp.Array(ctypes.c_uint, len(self.df)*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        self.shared_array = shared_array.reshape(len(self.df), h, w, c)
        self.shared_array = self.shared_array.astype("uint8")
        self.use_cache = False
        
    def __getitem__(self, index):
        if not self.use_cache:
            im_path = os.path.join(self.imfolder, self.df.iloc[index]['image_name'] + '.jpg')
            x = cv2.imread(im_path)
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
            self.shared_array[index] = x
        x = self.shared_array[index] 
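
You can check the resulting dtype directly; a quick sketch:

import ctypes
import multiprocessing as mp
import numpy as np

base = mp.Array(ctypes.c_uint, 4)
print(np.ctypeslib.as_array(base.get_obj()).dtype)  # uint32 on typical platforms, not uint8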

I have been able to test. What I am finding is that during the loading of the cache, I have to have num_workers set to 0. If I don’t, then I just get no data. Example:

pytorch_dataloader.num_workers = 1
for idx, (x, y) in enumerate(pytorch_dataloader):
    print(f"Loading Cache {idx*12}.\r", end='')
pytorch_dataloader.dataset.set_use_cache(use_cache=True)
images,meta = next(iter(pytorch_dataloader))
images[0]
tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]],

        [[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]],

        [[-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         ...,
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.],
         [-1., -1., -1.,  ..., -1., -1., -1.]]])

yet this works:

pytorch_dataloader.num_workers = 0
for idx, (x, y) in enumerate(pytorch_dataloader):
    print(f"Loading Cache {idx*12}.\r", end='')
pytorch_dataloader.dataset.set_use_cache(use_cache=True)
images,meta = next(iter(pytorch_dataloader))
images[0]
tensor([[[ 0.7098,  0.5137,  0.5608,  ...,  0.2392,  0.2235, -0.0588],
         [ 0.7176,  0.5137,  0.6078,  ...,  0.0039, -0.0980,  0.3255],
         [ 0.7255,  0.4902,  0.6314,  ..., -0.0275,  0.5294,  0.7647],
         ...,
         [ 0.6235,  0.6392,  0.6627,  ...,  0.7255,  0.7098,  0.6784],
         [ 0.6314,  0.6314,  0.6471,  ...,  0.7020,  0.7020,  0.7098],
         [ 0.6471,  0.6471,  0.6627,  ...,  0.7098,  0.6784,  0.6863]],

        [[ 0.3333,  0.1608,  0.2000,  ..., -0.1137, -0.1373, -0.4196],
         [ 0.3412,  0.1608,  0.2627,  ..., -0.3569, -0.4588, -0.0353],
         [ 0.3647,  0.1294,  0.2863,  ..., -0.3882,  0.1843,  0.4118],
         ...,
         [ 0.3490,  0.3647,  0.3804,  ...,  0.4196,  0.4039,  0.3725],
         [ 0.3647,  0.3647,  0.3647,  ...,  0.3804,  0.3804,  0.3882],
         [ 0.3804,  0.3804,  0.3569,  ...,  0.3725,  0.3412,  0.3490]],

        [[ 0.3020,  0.1216,  0.2000,  ..., -0.2627, -0.2549, -0.5373],
         [ 0.3255,  0.1373,  0.2549,  ..., -0.4824, -0.5608, -0.1373],
         [ 0.3647,  0.1294,  0.2941,  ..., -0.4902,  0.1137,  0.3647],
         ...,
         [ 0.3804,  0.3961,  0.3961,  ...,  0.4431,  0.4275,  0.3961],
         [ 0.3725,  0.3725,  0.3804,  ...,  0.4275,  0.4275,  0.4353],
         [ 0.3882,  0.3882,  0.3804,  ...,  0.4275,  0.3961,  0.4039]]])

I must be doing something wrong if I am not able to have num_workers > 0 while populating the cache. If you have any ideas please let me know; I am still troubleshooting.

I don’t know what might go wrong in your code.
Could you post standalone code to reproduce this issue?
Maybe my example is missing something critical and only works in this isolated use case.

@ptrblck Here it is, as simple as I could make it. I also put it into a notebook with 100 256x256 images and the associated CSV on Dropbox here if you want:

You can see I just do two tests: one where I set num_workers to 0 and load the cache, then flip use_cache to True, and everything seems to be fine; but if I have num_workers set to > 0 when I load the cache, then I get an empty image back.

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
import ctypes
import multiprocessing as mp
import os
import pandas as pd

class TestDataset(Dataset):
    def __init__(self, imfolder: str, transforms = None):
        
        self.imfolder = imfolder
        self.transforms = transforms
        self.n_images = len([f for f in os.listdir(imfolder) if f.endswith('.jpg')])
        self.df = pd.read_csv(f"{imfolder}/files.csv")
        c=3
        h=256
        w=256
        
        shared_array_base = mp.Array(ctypes.c_uint, self.n_images*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        self.shared_array = shared_array.reshape(self.n_images, h, w, c)
        self.shared_array = self.shared_array.astype("uint8")
        self.use_cache = False
        
    def __getitem__(self, index):
        if not self.use_cache:
            im_path = os.path.join(self.imfolder, self.df.iloc[index]['image_name'])
            x = cv2.imread(im_path)
            x = cv2.cvtColor(x, cv2.COLOR_BGR2RGB)
            self.shared_array[index] = x
        x = self.shared_array[index] 
        
        if self.transforms:
            x = self.transforms(x)
            
        return x

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache
    
    def __len__(self):
        return self.n_images
    
    
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])
    ])
pytorch_dataset = TestDataset(imfolder= "./test", transforms=transform)
pytorch_dataloader = DataLoader(dataset=pytorch_dataset, batch_size=10, shuffle=True,  pin_memory=True, num_workers=1)



# Works
pytorch_dataset = TestDataset(imfolder= "./test", transforms=transform)
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 0

for idx, x in enumerate(pytorch_dataloader):
    print(f"Loading Cache {idx*pytorch_dataloader.batch_size}.\r", end='')
    
pytorch_dataloader.dataset.set_use_cache(use_cache=True)

images = next(iter(pytorch_dataloader))
image = images[0]
plt.imshow(np.transpose(image, (1, 2, 0)))




# Does not work
pytorch_dataset = TestDataset(imfolder= "./test", transforms=transform)
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 1

for idx, x in enumerate(pytorch_dataloader):
    print(f"Loading Cache {idx*pytorch_dataloader.batch_size}.\r", end='')
    
pytorch_dataloader.dataset.set_use_cache(use_cache=True)

images = next(iter(pytorch_dataloader))
image = images[0]
plt.imshow(np.transpose(image, (1, 2, 0)))

Thanks for the code.
Unfortunately I still cannot reproduce it using predefined arrays as seen here:

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import ctypes
import multiprocessing as mp
import os
import pandas as pd

class TestDataset(Dataset):
    def __init__(self, imfolder: str, transforms = None):

        self.imfolder = imfolder
        self.transforms = transforms
        self.n_images = 10
        c=3
        h=256
        w=256

        shared_array_base = mp.Array(ctypes.c_uint, self.n_images*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        self.shared_array = shared_array.reshape(self.n_images, h, w, c)
        self.shared_array = self.shared_array.astype("uint8")
        self.use_cache = False

    def __getitem__(self, index):
        if not self.use_cache:
            x = np.ones((256, 256, 3)).astype(np.uint8) * index
            self.shared_array[index] = x
        x = self.shared_array[index]

        if self.transforms:
            x = self.transforms(x)

        return x

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def __len__(self):
        return self.n_images

transform = transforms.Compose([
    transforms.ToTensor(),
    ])
pytorch_dataset = TestDataset(imfolder= "./test", transforms=transform)
pytorch_dataloader = DataLoader(dataset=pytorch_dataset, batch_size=10, shuffle=True,  pin_memory=True, num_workers=1)

# Works
pytorch_dataset = TestDataset(imfolder= "./test", transforms=transform)
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 0

for idx, x in enumerate(pytorch_dataloader):
    print("Loading Cache {}".format(idx*pytorch_dataloader.batch_size))
    print(x.sum())

pytorch_dataloader.dataset.set_use_cache(use_cache=True)

images = next(iter(pytorch_dataloader))
image = images[0]
print(image.sum())

# Does not work
pytorch_dataset = TestDataset(imfolder= "./test", transforms=transform)
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 1

for idx, x in enumerate(pytorch_dataloader):
    print("Loading Cache {}".format(idx*pytorch_dataloader.batch_size))
    print(x.sum())

pytorch_dataloader.dataset.set_use_cache(use_cache=True)

images = next(iter(pytorch_dataloader))
image = images[0]
print(image.sum())

I removed the Normalization to manually verify that the result tensors show the expected results and not random values.
Also, since I’m working on a remote server, I just used the sum() of the output tensors instead of matplotlib.

Here is the output:

Loading Cache 0
tensor(0.)
Loading Cache 1
tensor(771.0118)
Loading Cache 2
tensor(1542.0236)
Loading Cache 3
tensor(2313.0354)
Loading Cache 4
tensor(3084.0471)
Loading Cache 5
tensor(3855.0588)
Loading Cache 6
tensor(4626.0708)
Loading Cache 7
tensor(5397.0825)
Loading Cache 8
tensor(6168.0942)
Loading Cache 9
tensor(6939.1060)
tensor(0.)
Loading Cache 0
tensor(0.)
Loading Cache 1
tensor(771.0118)
Loading Cache 2
tensor(1542.0236)
Loading Cache 3
tensor(2313.0354)
Loading Cache 4
tensor(3084.0471)
Loading Cache 5
tensor(3855.0588)
Loading Cache 6
tensor(4626.0708)
Loading Cache 7
tensor(5397.0825)
Loading Cache 8
tensor(6168.0942)
Loading Cache 9
tensor(6939.1060)
tensor(0.)

Such a strange problem! I put code in the Dataset to watch index 0, as all indices in the cache are being affected. In the first case, where we load the cache with num_workers=0, everything is fine: loading the cache and then reading it back out. In the second case, where num_workers=1 when loading the cache, this is what happens:

Current index 0, value of cache index 0 object is 35830685
Current index 1, value of cache index 0 object is 35830685
Current index 2, value of cache index 0 object is 35830685
Current index 3, value of cache index 0 object is 35830685
.
.
.
.
Current index 96, value of cache index 0 object is 35830685
Current index 97, value of cache index 0 object is 35830685
Current index 98, value of cache index 0 object is 35830685
Current index 99, value of cache index 0 object is 35830685
(now reading back out)
Current index 0, value of cache index 0 object is 0
Current index 1, value of cache index 0 object is 0
Current index 2, value of cache index 0 object is 0
Current index 3, value of cache index 0 object is 0
Current index 4, value of cache index 0 object is 0
.
.
.
.

So during the time the cache is being loaded, that index is not affected. But it must get hosed when the load finishes, because by the time you go to read it back it’s already messed up.

@ptrblck more progress. I eliminated all the transforms, cv2, and images, but I left pandas. The problem is still there. All you need is a dummy CSV file, no actual images, so you can just save this as files.csv:

image_name,value
ISIC_0030000.jpg,1
ISIC_0030001.jpg,2
ISIC_0030002.jpg,3
ISIC_0030003.jpg,4
ISIC_0030004.jpg,5
ISIC_0030005.jpg,6
ISIC_0030006.jpg,7
ISIC_0030007.jpg,8
ISIC_0030008.jpg,9
ISIC_0030009.jpg,10

This is the updated code:

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import ctypes
import multiprocessing as mp
import pandas as pd

class TestDataset(Dataset):
    def __init__(self, imfolder: str):
        
        self.imfolder = imfolder
        self.df = pd.read_csv(f"{imfolder}/files.csv")
        self.n_images = len(self.df) 

        c=3
        h=256
        w=256
        
        shared_array_base = mp.Array(ctypes.c_uint, self.n_images*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        self.shared_array = shared_array.reshape(self.n_images, h, w, c)
        self.shared_array = self.shared_array.astype("uint8")
        self.use_cache = False
        
    def __getitem__(self, index):
        if not self.use_cache:
            self.shared_array[index] = self.df.iloc[index]['value'] # x
        x = self.shared_array[index] 
        print(f" Current index {index}, value of cache index 0 object is {self.shared_array[0].sum()}")
        return x

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache
    
    def __len__(self):
        return self.n_images
    
pytorch_dataset = TestDataset(imfolder= "./test")
pytorch_dataloader = DataLoader(dataset=pytorch_dataset, batch_size=1, shuffle=True,  pin_memory=True, num_workers=1)

# Works
print("Works")
pytorch_dataset = TestDataset(imfolder= "./test")
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 0

for idx, x in enumerate(pytorch_dataloader):
    pass
#     print(f"Loading Cache {idx*pytorch_dataloader.batch_size}.\r", end='')
#     print(x.sum())

pytorch_dataloader.dataset.set_use_cache(use_cache=True)

for idx, x in enumerate(pytorch_dataloader):
    pass
#     print(f"x: {mylist[idx]:.5f}   x from cache: {x.sum():.5f}")


# Does not work
print("Does not work")
pytorch_dataset = TestDataset(imfolder= "./test")
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 1

for idx, x in enumerate(pytorch_dataloader):
    pass
#     print(f"Loading Cache {idx*pytorch_dataloader.batch_size}.\r", end='')
#     print(x.sum())

pytorch_dataloader.dataset.set_use_cache(use_cache=True)

for idx, x in enumerate(pytorch_dataloader):
    pass
    #     print(f"x: {mylist[idx]:.5f}   x from cache: {x.sum():.5f}")


And here is the output:

Works
 Current index 0, value of cache index 0 object is 196608
 Current index 1, value of cache index 0 object is 196608
 Current index 2, value of cache index 0 object is 196608
 Current index 3, value of cache index 0 object is 196608
 Current index 4, value of cache index 0 object is 196608
 Current index 5, value of cache index 0 object is 196608
 Current index 6, value of cache index 0 object is 196608
 Current index 7, value of cache index 0 object is 196608
 Current index 8, value of cache index 0 object is 196608
 Current index 9, value of cache index 0 object is 196608
 Current index 0, value of cache index 0 object is 196608
 Current index 1, value of cache index 0 object is 196608
 Current index 2, value of cache index 0 object is 196608
 Current index 3, value of cache index 0 object is 196608
 Current index 4, value of cache index 0 object is 196608
 Current index 5, value of cache index 0 object is 196608
 Current index 6, value of cache index 0 object is 196608
 Current index 7, value of cache index 0 object is 196608
 Current index 8, value of cache index 0 object is 196608
 Current index 9, value of cache index 0 object is 196608
Does not work
 Current index 0, value of cache index 0 object is 196608
 Current index 1, value of cache index 0 object is 196608
 Current index 2, value of cache index 0 object is 196608
 Current index 3, value of cache index 0 object is 196608
 Current index 4, value of cache index 0 object is 196608
 Current index 5, value of cache index 0 object is 196608
 Current index 6, value of cache index 0 object is 196608
 Current index 7, value of cache index 0 object is 196608
 Current index 8, value of cache index 0 object is 196608
 Current index 9, value of cache index 0 object is 196608
 Current index 0, value of cache index 0 object is 0
 Current index 1, value of cache index 0 object is 0
 Current index 2, value of cache index 0 object is 0
 Current index 3, value of cache index 0 object is 0
 Current index 4, value of cache index 0 object is 0
 Current index 5, value of cache index 0 object is 0
 Current index 6, value of cache index 0 object is 0
 Current index 7, value of cache index 0 object is 0
 Current index 8, value of cache index 0 object is 0
 Current index 9, value of cache index 0 object is 0

@ptrblck actually your code does show the problem, at least on my system.

The value you got back was "0", and the next value you will get back is 0 as well; they will all be 0’s. During the loading of the cache, whether it’s num_workers = 0 or num_workers = 1, you get the correct values back, as your output shows. But once that iterator completes, the values are hosed in the case of num_workers = 1. You can verify this by iterating again, as you see in my output: I look at index 0 and its value is 0 (using your np.ones code I look at the value of index 1, since index 0 is always 0 anyway).

Here is the updated code using just np.ones(), but with more verbose output:

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import ctypes
import multiprocessing as mp

class TestDataset(Dataset):
    def __init__(self):
        
        self.n_images = 10

        shared_array_base = mp.Array(ctypes.c_uint, self.n_images*3*256*256)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        self.shared_array = shared_array.reshape(self.n_images, 256, 256, 3)
        self.shared_array = self.shared_array.astype("uint8")
        self.use_cache = False
        
    def __getitem__(self, index):
        if not self.use_cache:
            self.shared_array[index] =  np.ones((256, 256, 3)).astype(np.uint8) * index
        x = self.shared_array[index] 
        return x

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache
    
    def __len__(self):
        return self.n_images
    
# Works
print("Works")
pytorch_dataset = TestDataset()
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 0

for idx, x in enumerate(pytorch_dataloader):
    print(f"Loading Cache {idx*pytorch_dataloader.batch_size}.\r", end='')
    print(f"Index {idx} = {x.sum()}")

pytorch_dataloader.dataset.set_use_cache(use_cache=True)

for idx, x in enumerate(pytorch_dataloader):
    print(f"Read from cache Index {idx} = {x.sum()}")
    
# Does not work
print("Does not work")
pytorch_dataset = TestDataset()
pytorch_dataloader = DataLoader(dataset=pytorch_dataset)
pytorch_dataloader.dataset.set_use_cache(use_cache=False)
pytorch_dataloader.num_workers = 1

for idx, x in enumerate(pytorch_dataloader):
    print(f"Loading Cache {idx*pytorch_dataloader.batch_size}.\r", end='')
    print(f"Index {idx} = {x.sum()}")

pytorch_dataloader.dataset.set_use_cache(use_cache=True)

for idx, x in enumerate(pytorch_dataloader):
    print(f"Read from cache Index {idx} = {x.sum()}")


Produces:

Works
Index 0 = 0
Index 1 = 196608
Index 2 = 393216
Index 3 = 589824
Index 4 = 786432
Index 5 = 983040
Index 6 = 1179648
Index 7 = 1376256
Index 8 = 1572864
Index 9 = 1769472
Read from cache Index 0 = 0
Read from cache Index 1 = 196608
Read from cache Index 2 = 393216
Read from cache Index 3 = 589824
Read from cache Index 4 = 786432
Read from cache Index 5 = 983040
Read from cache Index 6 = 1179648
Read from cache Index 7 = 1376256
Read from cache Index 8 = 1572864
Read from cache Index 9 = 1769472
Does not work
Index 0 = 0
Index 1 = 196608
Index 2 = 393216
Index 3 = 589824
Index 4 = 786432
Index 5 = 983040
Index 6 = 1179648
Index 7 = 1376256
Index 8 = 1572864
Index 9 = 1769472
Read from cache Index 0 = 0
Read from cache Index 1 = 0
Read from cache Index 2 = 0
Read from cache Index 3 = 0
Read from cache Index 4 = 0
Read from cache Index 5 = 0
Read from cache Index 6 = 0
Read from cache Index 7 = 0
Read from cache Index 8 = 0
Read from cache Index 9 = 0

My code snippet still returns the right values on my nodes after adding the “index multiplication” and your last loop:

import torch
from torch.utils.data import Dataset, DataLoader

import ctypes
import multiprocessing as mp

import numpy as np


class MyDataset(Dataset):
    def __init__(self):
        shared_array_base = mp.Array(ctypes.c_float, nb_samples*c*h*w)
        shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
        shared_array = shared_array.reshape(nb_samples, c, h, w)
        self.shared_array = torch.from_numpy(shared_array)
        self.use_cache = False

    def set_use_cache(self, use_cache):
        self.use_cache = use_cache

    def __getitem__(self, index):
        if not self.use_cache:
            print('Filling cache for index {}'.format(index))
            # Add your loading logic here
            self.shared_array[index] = torch.ones(c, h, w) * index
        x = self.shared_array[index]
        return x

    def __len__(self):
        return nb_samples


nb_samples, c, h, w = 10, 3, 24, 24

dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=2,
    shuffle=False
)

for epoch in range(2):
    for idx, data in enumerate(loader):
        print('Epoch {}, idx {}, data.shape {}'.format(epoch, idx, data.sum()))

    if epoch==0:
        loader.dataset.set_use_cache(True)

loader.dataset.set_use_cache(use_cache=True)

for idx, x in enumerate(loader):
        print("Read from cache Index {} = {}".format(idx, x.sum()))

Output:

Filling cache for index 0
Filling cache for index 1
Filling cache for index 2
Filling cache for index 3
Filling cache for index 4
Epoch 0, idx 0, data.shape 0.0
Epoch 0, idx 1, data.shape 1728.0
Filling cache for index 5
Epoch 0, idx 2, data.shape 3456.0
Filling cache for index 6
Epoch 0, idx 3, data.shape 5184.0
Filling cache for index 7
Epoch 0, idx 4, data.shape 6912.0
Filling cache for index 8
Epoch 0, idx 5, data.shape 8640.0
Filling cache for index 9
Epoch 0, idx 6, data.shape 10368.0
Epoch 0, idx 7, data.shape 12096.0
Epoch 0, idx 8, data.shape 13824.0
Epoch 0, idx 9, data.shape 15552.0
Epoch 1, idx 0, data.shape 0.0
Epoch 1, idx 1, data.shape 1728.0
Epoch 1, idx 2, data.shape 3456.0
Epoch 1, idx 3, data.shape 5184.0
Epoch 1, idx 4, data.shape 6912.0
Epoch 1, idx 5, data.shape 8640.0
Epoch 1, idx 6, data.shape 10368.0
Epoch 1, idx 7, data.shape 12096.0
Epoch 1, idx 8, data.shape 13824.0
Epoch 1, idx 9, data.shape 15552.0
Read from cache Index 0 = 0.0
Read from cache Index 1 = 1728.0
Read from cache Index 2 = 3456.0
Read from cache Index 3 = 5184.0
Read from cache Index 4 = 6912.0
Read from cache Index 5 = 8640.0
Read from cache Index 6 = 10368.0
Read from cache Index 7 = 12096.0
Read from cache Index 8 = 13824.0
Read from cache Index 9 = 15552.0

That being said, I’m not familiar enough with potential side effects of mp.Array or any OS-dependent issues to be of any real help. :confused:
My best guess is that the change in dtype might create these issues (as it seems to be the one change you’ve introduced). If you are not getting the right values with this code snippet, I would guess some multiprocessing- or OS-related differences.
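One quick check (just a sketch) would be to verify that the array you are writing to is still backed by the shared mp.Array buffer, e.g. with np.shares_memory:

import ctypes
import multiprocessing as mp
import numpy as np

base = mp.Array(ctypes.c_uint, 10 * 3 * 256 * 256)
view = np.ctypeslib.as_array(base.get_obj()).reshape(10, 256, 256, 3)
print(np.shares_memory(view, np.ctypeslib.as_array(base.get_obj())))  # True: reshape keeps the view

copy = view.astype("uint8")  # astype returns a new array
print(np.shares_memory(copy, np.ctypeslib.as_array(base.get_obj())))  # False: no longer shared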

Thanks for your help. So, this all comes down to the fact that a c_uint is not an np.uint8; the compatible type is c_ubyte, which corresponds to np.uint8. And beyond that, you know, I have coded my share of low-level memory managers where I have run into such strange things, but I still cannot fathom why you could a) write to the cache, b) read it back and it’s fine, and only after everything was done were things hosed. Unless there is some sort of caching happening under the hood, which is possible.
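For reference, a minimal sketch of the allocation that works for me now (the sizes are placeholders):

import ctypes
import multiprocessing as mp
import numpy as np

n_images, c, h, w = 100, 3, 256, 256

# c_ubyte maps to np.uint8, so no .astype() is needed
# (.astype() would return a copy that is no longer backed by the shared mp.Array)
shared_array_base = mp.Array(ctypes.c_ubyte, n_images * c * h * w)
shared_array = np.ctypeslib.as_array(shared_array_base.get_obj())
shared_array = shared_array.reshape(n_images, h, w, c)
print(shared_array.dtype)  # uint8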

It now works as it should!