Reset DataLoader

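I’m using four DataLoaders to make BERT training more efficient. I have four datasets with different sequence lengths (8, 16, 32, 64), and the corresponding batch sizes are 512, 256, 128 and 64. This keeps memory consumption under control, because every batch holds the same number of token positions (8 × 512 = 64 × 64 = 4096).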
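The batch sizes follow a simple rule: batch size × sequence length is held constant at 4096, which is what `batch_size=num_samples // size` in the code produces if `num_samples` is 4096. A small check of that arithmetic (the name `TOKENS_PER_BATCH` is illustrative, not from the original code):

```python
# Assumption: a fixed budget of 4096 token positions per batch,
# which reproduces the batch sizes (512, 256, 128, 64) from the post.
TOKENS_PER_BATCH = 4096
SEQ_LENGTHS = [8, 16, 32, 64]

# Same rule as `batch_size=num_samples // size` in the DataLoader code.
batch_sizes = {seq_len: TOKENS_PER_BATCH // seq_len for seq_len in SEQ_LENGTHS}

for seq_len, batch_size in batch_sizes.items():
    # Each batch tensor has shape (batch_size, seq_len), i.e. 4096 positions.
    print(f"seq_len={seq_len:>2}  batch_size={batch_size:>3}  tokens={seq_len * batch_size}")
```

This equalizes the token count, not total memory: attention scores scale with batch size × sequence length², so the length-64 batches still need somewhat more memory than the length-8 ones.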

I run the backward pass on one batch from each DataLoader, accumulating the gradients, and then take a single optimizer step. Each DataLoader has a different length, so `zip` stops as soon as the shortest one is exhausted. How can I reset the DataLoaders when using `zip`?
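To illustrate how `zip` treats loaders of different lengths, here is a pure-Python sketch in which a hypothetical `FakeLoader` class stands in for `DataLoader` (no PyTorch needed). It shows two things: `zip` stops at the shortest input, and every new `zip(*loaders)` call, as at the top of each epoch, calls `iter()` on each loader again. For a real `DataLoader` that starts a fresh pass, reshuffled when `shuffle=True`.

```python
class FakeLoader:
    """Hypothetical stand-in for a DataLoader: yields `num_batches` items and
    counts how many times a fresh pass (an iter() call) was started."""

    def __init__(self, name, num_batches):
        self.name = name
        self.num_batches = num_batches
        self.passes_started = 0

    def __iter__(self):
        # Like a DataLoader, every iter() call starts again from the beginning.
        self.passes_started += 1
        return (f"{self.name}{i}" for i in range(self.num_batches))


loaders = [FakeLoader("a", 3), FakeLoader("b", 5)]

for epoch in range(2):
    # zip stops at the shortest loader: 3 steps, the last 2 "b" batches are skipped.
    steps = list(zip(*loaders))
    print(epoch, steps)

# zip(*loaders) called iter() on every loader once per epoch.
print([loader.passes_started for loader in loaders])
```

With a real DataLoader, the batches left over in the longer loaders are simply not used in that epoch; since `shuffle=True` draws a new permutation on every pass, different samples are left over each time.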

Here is some example code:

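```python
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader

# dataset, collate_fn, num_samples, num_epochs, device, use_amp,
# _model, criterion, _optimizer and scaler (a GradScaler) are defined elsewhere.
data_loaders = []
for size in [8, 16, 32, 64]:
    data_loaders.append(
        DataLoader(
            dataset=dataset[size],
            collate_fn=collate_fn,
            batch_size=num_samples // size,  # 512, 256, 128, 64 for num_samples == 4096
            num_workers=8,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
        )
    )

for epoch in range(1, num_epochs + 1):
    # zip stops when the shortest DataLoader is exhausted.
    for i, batches in enumerate(zip(*data_loaders)):
        # One batch per sequence length; gradients accumulate across all four.
        for batch, labels in batches:
            batch = batch.to(device)
            labels = labels.to(device)

            with autocast(enabled=use_amp):
                predict = _model(batch)
                loss = criterion(predict, labels)

            scaler.scale(loss).backward()

        # One optimizer step per group of four batches.
        scaler.step(_optimizer)
        scaler.update()
        _optimizer.zero_grad(set_to_none=True)
```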
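If the goal is instead to keep going until the longest DataLoader is exhausted, one option is to restart each shorter loader with a fresh `iter()` whenever it raises `StopIteration`. Below is a sketch of that idea as a hypothetical helper `zip_restarting` (an illustrative name, not a PyTorch function). It only assumes that each loader supports `len()` and `iter()`, which a `DataLoader` over a map-style dataset does; plain lists stand in for the loaders in the usage lines.

```python
from typing import Iterator, Tuple


def zip_restarting(*loaders) -> Iterator[Tuple]:
    """Hypothetical helper (not a PyTorch API): yield one batch from every loader
    per step until the LONGEST loader is exhausted. Shorter loaders are restarted
    with a fresh iter() whenever they run out."""
    if not loaders:
        return
    lengths = [len(loader) for loader in loaders]
    if min(lengths) == 0:
        raise ValueError("every loader must yield at least one batch")

    iterators = [iter(loader) for loader in loaders]
    for _ in range(max(lengths)):
        step = []
        for idx, loader in enumerate(loaders):
            try:
                batch = next(iterators[idx])
            except StopIteration:
                # This loader is exhausted: start a new pass over it.
                iterators[idx] = iter(loader)
                batch = next(iterators[idx])
            step.append(batch)
        yield tuple(step)


# Stand-ins for two DataLoaders with different numbers of batches.
short_loader = ["s0", "s1"]
long_loader = ["l0", "l1", "l2", "l3", "l4"]
for step in zip_restarting(short_loader, long_loader):
    print(step)
```

In the training loop this would mean replacing `zip(*data_loaders)` with `zip_restarting(*data_loaders)`, so one epoch becomes one pass over the longest loader and the shorter datasets are revisited (and reshuffled) within it. `itertools.cycle` looks like a shortcut, but it stores every item from the first pass and replays them in the same order, so it would hold all batches in memory and never reshuffle. With `num_workers=8`, each restart also spawns new worker processes unless the DataLoader is created with `persistent_workers=True`.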