FSDP takes a very long time when starting to iterate the DataLoader in each epoch

I’m trying to train a CLIP-based model using FSDP on three GPUs, but at the start of each epoch, when iteration over the data loader begins, it takes about 1 minute to read roughly 4K samples.
Key Dataset code:

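from os.path import join

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class MyDataset(Dataset):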
    def __init__(self, …):
        self.label_table = pd.read_csv(…) # About 4K rows

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        img = Image.open(join(self.img_dir, self.label_table.loc[idx, 'image']))
        if self.transform is not None:
            img = self.transform(img)
        return img, …

DataLoader-related code:

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((image_size, ) * 2),
    # …
])

train_ds = MyDataset(…, transform=transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds, rank=rank, num_replicas=args.world_size)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=args.batch_size, pin_memory=False, num_workers=48, sampler=train_sampler, drop_last=True)

Training code:

model = MyModel(train_ds.get_ti_rads_size()).float().to(rank)
model = FullyShardedDataParallel(model, auto_wrap_policy=my_auto_wrap_policy, cpu_offload=CPUOffload(offload_params=True), use_orig_params=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for epoch in range(1, args.epochs + 1):
    fsdp_loss = torch.zeros(2).to(rank)
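    if train_sampler:
        # Give the DistributedSampler the epoch so it shuffles differently each time.
        train_sampler.set_epoch(epoch)
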
    if rank == 0:
        epoch_progress = tqdm.tqdm(range(len(train_loader)), colour='blue', desc=f'Training Epoch {epoch}')
    # The hang occurs on the line below at the start of each epoch
    for image, …, _ in train_loader:
        # Copy tensors to devices.
        image = image.to(rank)
        # …

The program gets stuck at the line for image, …, _ in train_loader: for about 1 minute at the start of each epoch.
When I add a debug print to the __getitem__ function of the Dataset, it seems that all images are read again every time the DataLoader is iterated. How can I avoid this?

I have tried some approaches, but none of them worked:

  1. Set pin_memory=True.
  2. Read all images in the constructor of the Dataset, and return the preloaded image tensors directly in the __getitem__ function. (This suggests the Dataset is reloaded at every epoch.)

Your current MyDataset.__init__ method only loads self.label_table via pd.read_csv, so I assume you are referring to this line of code?
This behavior is expected, since each new DataLoader iterator recreates the worker processes, and each worker gets its own copy of the Dataset. If you want to avoid this, you could set persistent_workers=True in the DataLoader, which keeps the workers alive and reuses them in each epoch.
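
As a rough illustration of the difference, here is a minimal sketch with a toy dataset. ToyDataset and worker_pids_per_epoch are hypothetical names made up for this example, not part of the code in the question. Each item carries the PID of the worker process that produced it, so you can check whether the same workers serve consecutive epochs.

```python
import os

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for MyDataset (assumption, not the original class)."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Return the PID of the process that loaded this item.
        return torch.tensor(idx), os.getpid()


def worker_pids_per_epoch(persistent, epochs=2):
    """Collect the set of worker PIDs observed in each epoch."""
    loader = DataLoader(
        ToyDataset(),
        batch_size=2,
        num_workers=2,
        persistent_workers=persistent,
    )
    pids = []
    for _ in range(epochs):
        epoch_pids = set()
        # Every `for` loop over the loader creates a new iterator.
        for _, pid in loader:
            epoch_pids.update(pid.tolist())
        pids.append(epoch_pids)
    return pids


if __name__ == "__main__":
    # Without persistent workers, each epoch starts fresh worker processes.
    print("persistent_workers=False:", worker_pids_per_epoch(False))
    # With persistent workers, the same processes are reused across epochs.
    print("persistent_workers=True: ", worker_pids_per_epoch(True))
```

In the training script from the question, the corresponding change would be to add persistent_workers=True to the existing DataLoader(...) call. With num_workers=48 on each rank, restarting all workers every epoch is probably a large part of the delay, although I have not measured that.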

It works! Thank you very much! :grinning: