Hey folks,
I have a server with a large amount of RAM but slow storage, and I want to speed up training by keeping my dataset in RAM. I also use DDP, which means there are going to be multiple processes, one per GPU. On top of that, I use multiple num_workers in my dataloader, so a simple Python list as a cache would mean multiple caches, which eats up a lot of memory.
The natural solution is to use shared memory. This is how I use it:
- In the launch process, do
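    if __name__ == '__main__':
        import argparse
        import os
        import ctypes
        import numpy as np
        import torch.multiprocessing as mp

        # One byte per channel value: 80000 images of 256x256x3
        shared_base = mp.Array(ctypes.c_byte, 80000*3*256*256, lock=True)
        with shared_base.get_lock():
            # NumPy view over the shared buffer
            shared_array = np.ctypeslib.as_array(shared_base.get_obj())
        img_cache = shared_array.reshape(80000, 256, 256, 3)
        # Flag: negative means "still filling the cache", positive means "read from it"
        use_cache = mp.Array(ctypes.c_float, 1, lock=False)
        use_cache[0] = -3.14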
- This cache is sent to each process as
    mp.spawn(main, nprocs=ngpus_per_node, args=(args, img_cache, use_cache))
- Each process takes this shared memory and gives it to a dataset object
    dset = SVAE_FFHQ(args.data_folder, transform, 32, 64, args.hidden_size, img_cache, use_cache)
- The SVAE_FFHQ class looks like this:
    from copy import deepcopy

    import numpy as np
    from PIL import Image
    from torch.utils import data

    class SVAE_FFHQ(data.Dataset):
        def __init__(self, root_dir, transform=None, top_size=32, bottom_size=64, dim=256, img_cache=None, use_cache=None):
            super().__init__()
            ...
            self.img_cache = img_cache
            self.use_cache = use_cache

        def _use_cache(self):
            # Flip the shared flag so later __getitem__ calls skip the disk read
            self.use_cache[0] = 3.14
            print('Using cache')

        def __getitem__(self, idx):
            path, lbl = self.dset.samples[idx]
            if self.use_cache[0] < 0:
                # Cache not ready yet: load from disk and store the resized image
                with open(path, 'rb') as f:
                    img = Image.open(f)
                    img = img.convert('RGB')
                    img = img.resize((256, 256), Image.LANCZOS)
                    self.img_cache[idx] = deepcopy(np.asarray(img))
                    del img
            return self.transform(Image.fromarray(self.img_cache[idx], 'RGB'))
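For scale, here is a quick back-of-the-envelope calculation of what one copy of this cache costs and how that grows if it ends up duplicated. It is only a sketch: the helper names and the copy counts are hypothetical, and it assumes one byte per channel value, as in the mp.Array allocation in the launch step.

```python
def cache_bytes(num_images=80000, height=256, width=256, channels=3,
                bytes_per_value=1):
    """Size in bytes of one copy of the image cache."""
    return num_images * height * width * channels * bytes_per_value


def total_gib(copies, **kwargs):
    """Total memory in GiB if the cache is duplicated `copies` times."""
    return copies * cache_bytes(**kwargs) / 2**30


if __name__ == "__main__":
    # The copy counts are illustrative, not measured on my machine
    for copies in (1, 2, 4, 8):
        print(f"{copies} copies: {total_gib(copies):.1f} GiB")
```

One copy comes to roughly 14.6 GiB, so even a handful of duplicates takes a large share of the RAM.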
This setup seems fine to me, but what actually happens is:
- The shared memory appears to be pickled and copied into each spawned process rather than shared, which means my memory requirements increase with the number of processes spawned.
- This isn’t any faster than reading the data off slow HDDs.
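The first symptom seems reproducible without DDP. The sketch below is my guess at the mechanism, not something confirmed: it assumes that mp.spawn pickles its args, uses the standard multiprocessing module in place of torch.multiprocessing, uses lock=False for simplicity, and swaps the real cache for a tiny buffer. The function name is made up for illustration.

```python
import ctypes
import multiprocessing as mp
import pickle

import numpy as np


def pickle_round_trip_shares_memory(num_bytes=16):
    """Return True if a pickled-and-restored view still aliases the shared buffer."""
    # Small stand-in for the real cache; lock=False gives a raw ctypes array
    shared_base = mp.Array(ctypes.c_byte, num_bytes, lock=False)
    view = np.ctypeslib.as_array(shared_base)

    # Simulate what happens when the NumPy view is sent to another process
    restored = pickle.loads(pickle.dumps(view))

    # Write through the restored array and check whether the original sees it
    restored[0] = 42
    return bool(view[0] == 42)


if __name__ == "__main__":
    print("still shared after pickling:", pickle_round_trip_shares_memory())
```

If this reports False, as I expect, then each spawned process would get its own private copy of img_cache. Passing shared_base itself and rebuilding the NumPy view inside each process might avoid that, but I have not verified it.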
Any insight into these problems?
Thanks!