I’m trying to train a CLIP-based model using FSDP on three GPUs, but at the start of each epoch, when iteration over the data loader begins, it takes about 1 minute to read roughly 4K samples.
Key `Dataset` code:
```python
from os.path import join

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, …):
        self.label_table = pd.read_csv(…)  # About 4K rows

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        img = Image.open(join(self.img_dir, self.label_table.loc[idx, 'image']))
        if self.transform is not None:
            img = self.transform(img)
        return img, …
```
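The debug-print observation matches how a map-style `DataLoader` normally behaves: it calls `__getitem__` once for every index the sampler yields, in every epoch, and does not cache the results, so each image is opened and transformed again each epoch. A minimal sketch of a hypothetical `CountingDataset` wrapper that makes this visible without print statements (in practice `DataLoader` accepts any object with `__len__` and `__getitem__`, so it could wrap `MyDataset`):

```python
class CountingDataset:
    """Hypothetical wrapper that counts __getitem__ calls on a map-style dataset.

    The count is only visible in the main process when num_workers=0;
    with worker processes, each worker increments its own copy.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.calls = 0

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        self.calls += 1
        return self.dataset[idx]


# Simulate two epochs in which a sampler visits every index once.
ds = CountingDataset(["img0", "img1", "img2", "img3"])
for epoch in range(2):
    for idx in range(len(ds)):
        _ = ds[idx]
print(ds.calls)  # 8: each of the 4 items is read once per epoch
```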
`DataLoader`-related code:
```python
import torch
import torchvision

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((image_size, ) * 2),
    torchvision.transforms.ToTensor()
])
train_ds = MyDataset(…, transform=transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_ds, rank=rank, num_replicas=args.world_size)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=args.batch_size, pin_memory=False,
    num_workers=48, sampler=train_sampler, drop_last=True)
```
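The per-epoch workload of this `DataLoader` setup can be worked out by hand. The sketch below assumes exactly 4000 rows and a batch size of 32 (the post only says "4K" and does not give `args.batch_size`); `loader_stats` is a hypothetical helper that mirrors `DistributedSampler`'s default `drop_last=False` (pad so every rank gets `ceil(N / world_size)` indices) and the loader's `drop_last=True`.

```python
import math


def loader_stats(n_samples, world_size, batch_size, num_workers):
    """Hypothetical helper: rough per-epoch numbers for one DataLoader per rank."""
    # DistributedSampler (drop_last=False) pads so each rank gets the same count.
    samples_per_rank = math.ceil(n_samples / world_size)
    # DataLoader(drop_last=True) discards the last incomplete batch.
    batches_per_rank = samples_per_rank // batch_size
    # Every rank has its own DataLoader, each with num_workers processes.
    total_workers = world_size * num_workers
    return samples_per_rank, batches_per_rank, total_workers


# Assumed values: 4000 rows, 3 GPUs, batch size 32, num_workers=48.
print(loader_stats(4000, 3, 32, 48))  # (1334, 41, 144)
```

Under those assumptions each rank would start 48 workers (144 processes in total) to produce only 41 batches, so several workers per rank would never receive a batch, and `len(train_loader)`, which sizes the tqdm bar, would be 41.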
Training code:
```python
import torch
import tqdm
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel

model = MyModel(train_ds.get_ti_rads_size()).float().to(rank)
model = FullyShardedDataParallel(
    model,
    auto_wrap_policy=my_auto_wrap_policy,
    cpu_offload=CPUOffload(offload_params=True),
    use_orig_params=True,
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for epoch in range(1, args.epochs + 1):
    fsdp_loss = torch.zeros(2).to(rank)
    if train_sampler:
        train_sampler.set_epoch(epoch)
    if rank == 0:
        epoch_progress = tqdm.tqdm(range(len(train_loader)), colour='blue', desc=f'Training Epoch {epoch}')
    # The hang happens on the line below, at the start of every epoch
    for image, …, _ in train_loader:
        # Copy tensors to devices.
        image = image.to(rank)
        # …
```
The program gets stuck at `for image, …, _ in train_loader:` for about 1 minute at the start of each epoch.
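To check whether that minute is spent before the first batch arrives (start-up) or spread over all batches, the loader can be wrapped in a small timing generator. This is a sketch; `timed` and `SlowStartLoader` are hypothetical names, and it could be used on rank 0 as `for image, …, _ in timed(train_loader):`.

```python
import time


def timed(loader, report=print):
    """Yield batches from `loader`, reporting how long each one took to arrive.

    The first wait includes iter(loader), which is where a multi-worker
    DataLoader starts its worker processes. The timer restarts after each
    yield, so time spent in the training step is not counted.
    """
    start = time.perf_counter()
    for i, batch in enumerate(loader):
        report(i, time.perf_counter() - start)
        yield batch
        start = time.perf_counter()


class SlowStartLoader:
    """Hypothetical stand-in whose iter() is slow, like worker start-up."""

    def __iter__(self):
        time.sleep(0.2)
        return iter(range(3))


if __name__ == "__main__":
    for batch in timed(SlowStartLoader(), report=lambda i, s: print(f"batch {i}: waited {s:.3f}s")):
        pass
```

If almost all of the time shows up as the wait for batch 0, the cost is in starting the iteration rather than in reading individual samples.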
When I add a debug print to the `__getitem__` method of the `Dataset`, it appears that all images are read again every time the `DataLoader` is iterated. How can I avoid this?
I have tried the following approaches, but none of them worked:
- Setting `pin_memory=True`.
- Reading all images in the `Dataset` constructor and returning the preloaded image tensors directly from `__getitem__`. (That this did not help might suggest the `Dataset` is being reloaded at every epoch.)
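One thing that may be worth trying, offered as an assumption rather than a confirmed diagnosis: with the default `persistent_workers=False`, each epoch's `for` loop calls `iter(train_loader)`, which starts `num_workers` fresh worker processes, each working on its own copy of the dataset, so preloading images in the constructor would not by itself avoid that per-epoch cost. `persistent_workers=True` (a real `DataLoader` argument) keeps the workers alive between epochs, and a smaller `num_workers` reduces start-up work. `ToyDataset` and `make_loader` are hypothetical names and the worker count is illustrative.

```python
import time

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for MyDataset that returns small fake images."""

    def __init__(self, n=64):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # The real MyDataset would open and transform an image file here.
        return torch.full((3, 8, 8), float(idx)), idx


def make_loader(dataset, batch_size=8, num_workers=4, sampler=None):
    """Build a DataLoader whose worker processes survive between epochs."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=True,
        # persistent_workers=True is only allowed when num_workers > 0.
        persistent_workers=num_workers > 0,
    )


if __name__ == "__main__":
    loader = make_loader(ToyDataset(), num_workers=4)
    for epoch in range(1, 4):
        start = time.perf_counter()
        n_batches = sum(1 for _ in loader)
        print(f"epoch {epoch}: {n_batches} batches in {time.perf_counter() - start:.3f}s")
```

Because the sampler is iterated in the main process, `train_sampler.set_epoch(epoch)` should still change the shuffling order each epoch when workers are persistent.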