I have a dataset that I’m feeding into a NN, but the dataset is larger than the available memory on my machine (training on a single CPU with 24 cores and 32 GB of memory). I’m trying to load the data in batches and train in parallel, but for some reason the memory for all of the batches is being used at once and the system crashes: if my total dataset is 40 GB, then instead of loading (40 GB / nbatches) * nprocesses, it loads the full 40 GB at once.
I’ve tried various implementations of DDP, torch.multiprocessing, torch.distributed, and DataParallel, but I haven’t been able to figure it out.
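What I’d really like to be able to confirm is whether each worker ends up holding the whole 40 GB or just its own slice of the batch files. The kind of check I have in mind looks something like this (a psutil-based sketch, not my actual code; the file path matches the Dataset shown below):

import os
import pickle
import psutil

def report_rss(tag):
    # Print this process's resident set size in GB.
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3
    print('[{}] pid={} rss={:.2f} GB'.format(tag, os.getpid(), rss_gb))

report_rss('before loading one batch')
with open(os.path.join(os.getcwd(), 'batches', 'batches_0.pckl'), 'rb') as f:
    X = pickle.load(f)  # expectation: RSS grows by roughly 40 GB / nbatches, not 40 GB
report_rss('after loading one batch')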
The code is essentially:
import os
import pickle
import torch

class Dataset(torch.utils.data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self):
        'Initialization'
        self.srcDir = os.getcwd() + '/batches'
        self.data = os.listdir(self.srcDir)
    def __len__(self):
        return len(self.data)
    def __getitem__(self, index):
        # Each pickle file holds one pre-built batch, so each call should
        # only read that single batch into memory.
        fname = os.path.join(self.srcDir,
                             'batches_{}.pckl'.format(index))
        with open(fname, 'rb') as f1:
            X = pickle.load(f1)
        return X
def parallelFit(self, partition):  # method on the calc object used below
    # `partition` is the subset of batches this process should train on
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.optim import LBFGS
    train_dataset = Dataset()
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    batches = torch.utils.data.DataLoader(dataset=train_dataset, sampler=train_sampler)
    ddp_model = DDP(self.model, device_ids=[])
    optimizer = LBFGS(ddp_model.parameters())
    optimizer.step()  # training loop condensed; LBFGS.step() actually takes a closure
# I see the same issue if I don't use DDP and if I don't use the DistributedSampler
Then everything is run with:
from torch.multiprocessing import Process
# num_processes, calc, and partitions come from the surrounding script
processes = []
for rank in range(num_processes):
    p = Process(target=calc.parallelFit, args=(partitions.use(rank),))
    p.start()
    processes.append(p)
for p in processes:
    p.join()
I’ve been reading through the forums and the documentation, but I feel like I’m just missing something.