DDP with an out-of-memory, generator datapipe

For context, my dataset is a set of parquet files, each with a variable number of rows. I am trying to train using DDP, but my dataset is too large to load into one process, let alone several. When training on one GPU, it is simple enough to set up a generator using pyarrow: I create a generator for each parquet file, chain them together, and pass the result to a DataLoader. Here is some example code:

import io
import pyarrow.parquet as pq
from itertools import chain
from torchdata.datapipes.iter import IterableWrapper, Shuffler

files = [<list of parquet files>]

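def get_datapipe(files):
    # one lazy generator per file, chained into a single record stream
    all_parquet_gen = chain(*[get_parquet_generator(f) for f in files])
    # bs (the batch size) is defined elsewhere
    my_datapipe = Shuffler(IterableWrapper(all_parquet_gen, deepcopy=False), buffer_size=bs*4)
    return my_datapipe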

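def get_parquet_generator(f):
    # NOTE: `obj` is not defined in this snippet; presumably it is the already-fetched object for `f`
    parquet_file = pq.ParquetFile(io.BytesIO(obj['Body'].read()))
    # stream the file one row at a time instead of loading it whole
    gen = parquet_file.iter_batches(batch_size=1)
    for batch in gen:
        yield single_record_to_torch(batch.to_pylist()[0])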

def single_record_to_torch(rec):
    <Transforms a single record from pyarrow to a torch-friendly format>

Currently, just trying to run DDP with the same code gives the following error:

TypeError: cannot pickle 'generator' object
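
Python generator objects can never be pickled, so this error appears whenever the datapipe is serialized, for example when it is handed to another process. One possible workaround, shown here as a minimal standard-library sketch and not as a torchdata recipe, is to keep only the list of files in the object and build the generators inside `__iter__`. The names `LazyChain`, `fake_parquet_generator` and `is_picklable` are hypothetical and exist only for this illustration; `fake_parquet_generator` stands in for `get_parquet_generator`.

```python
import pickle
from itertools import chain


def fake_parquet_generator(f):
    # Hypothetical stand-in for get_parquet_generator: yields two records per file.
    for i in range(2):
        yield (f, i)


class LazyChain:
    """Iterable that stores file names and only creates generators when iterated."""

    def __init__(self, files, make_generator):
        self.files = list(files)
        # Should be a module-level function so that it can be pickled by reference.
        self.make_generator = make_generator

    def __iter__(self):
        # Generators are created here, after any pickling has already happened.
        return chain.from_iterable(self.make_generator(f) for f in self.files)


def is_picklable(obj):
    # Returns True if pickle can serialize obj, False otherwise.
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, AttributeError, pickle.PicklingError):
        return False


if __name__ == "__main__":
    # A live generator cannot be pickled.
    print(is_picklable(fake_parquet_generator("a.parquet")))
    # An object holding only file names and a function reference can.
    lazy = LazyChain(["a.parquet", "b.parquet"], fake_parquet_generator)
    print(is_picklable(lazy))
    print(list(lazy))
```

An object like this can also be iterated more than once, because every call to `__iter__` starts fresh generators, whereas a chained generator is exhausted after the first epoch.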

Here is my code for running DDP:

from dataset import get_datapipe
from model import MyModel

from torch.optim import Adam
from torch.nn import MSELoss

import torch
from tqdm import tqdm

import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    # tear down the process group created in setup()
    dist.destroy_process_group()

# Is this the correct way to do a loop with DDP?
def train_model(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # build the training datapipe, plus a small one used only to fetch a dummy batch
    files = get_keys_paginator('gogneptune', 'gogv2_input/v2-1/full_prefiltered')
    train_dl = get_datapipe(files)
    dummy_dl = get_datapipe(files[:2])
    batch = next(iter(dummy_dl)).to(rank)

    # create model and move it to GPU with id rank
    model = MyModel()
    model = model.to(rank)
    # one forward pass on the dummy batch before wrapping the model in DDP
    model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
    ddp_model = DDP(model, device_ids=[rank])
    dummy_dl = None

    optimizer = Adam(params=model.parameters())
    criterion = MSELoss().to(rank)

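    n_epoch = 5
    for n in range(n_epoch):
        for batch in tqdm(train_dl):
            batch = batch.to(rank)
            optimizer.zero_grad()
            output = ddp_model(batch.x_dict, batch.edge_index_dict, batch.batch_dict)
            y = batch.y
            loss = criterion(output, y.reshape(-1,1))
            # backward pass; DDP averages gradients across ranks here
            loss.backward()
            optimizer.step()

    cleanup()
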
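def run_training(train_fn, world_size):
    # start one process per GPU; each process receives its rank as the first argument
    mp.spawn(train_fn, args=(world_size,), nprocs=world_size, join=True)
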
if __name__ == '__main__':
    num_devices = torch.cuda.device_count() 
    run_training(train_model, num_devices)

I am wondering how I can adapt my workflow to a DDP setting, given that my data does not fit in memory and that parquet files have a variable number of rows. (So I can't simply shard the list of files, as the shards would still not hold an equal amount of data.) If it changes anything, I am working on a single machine with multiple GPUs. Advice is appreciated!
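
To make the uneven-shard problem concrete, here is a minimal sketch of one possible approach; it is an assumption on my part, not an established recipe. It balances files across ranks by row count, then has every rank stop after the smallest shard's total so that all ranks run the same number of steps. The helper `assign_files` and the example row counts are hypothetical. In practice the counts could be read from each file's footer with `pq.ParquetFile(f).metadata.num_rows`, which does not load the row data.

```python
import heapq


def assign_files(row_counts, world_size):
    """Greedy largest-first assignment of files to ranks, balancing total rows."""
    # Min-heap of (rows assigned so far, rank); the least-loaded rank is popped first.
    heap = [(0, rank) for rank in range(world_size)]
    shards = [[] for _ in range(world_size)]
    # Visit the largest files first; ties are broken by file name for determinism.
    for name, rows in sorted(row_counts.items(), key=lambda kv: (-kv[1], kv[0])):
        total, rank = heapq.heappop(heap)
        shards[rank].append(name)
        heapq.heappush(heap, (total + rows, rank))
    totals = [sum(row_counts[f] for f in shard) for shard in shards]
    return shards, totals


if __name__ == "__main__":
    # Hypothetical row counts, for illustration only.
    counts = {"a.parquet": 100, "b.parquet": 70, "c.parquet": 40, "d.parquet": 30}
    shards, totals = assign_files(counts, world_size=2)
    print(shards)       # files assigned to each rank
    print(totals)       # rows assigned to each rank
    print(min(totals))  # records each rank would consume per epoch
```

Each rank would then read only `shards[rank]` and cap its stream, for example with `itertools.islice(stream, min(totals))`, at the cost of dropping a few records per epoch. An alternative that avoids truncation may be the `Join` context manager in `torch.distributed.algorithms.join`, which is intended for training with uneven inputs.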