Hello,
I’m running single-node, multi-GPU training on data in WebDataset format (~4000 shards, each containing 1000 .flac files). My goal is to train for N epochs, where each epoch consists of config["num_batches_per_epoch"] optimizer steps.
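For concreteness, the config entries referenced below look roughly like this; the actual values are made up here for illustration (only the keys match my real config):

# Hypothetical values, only to make the arithmetic below concrete.
config = {
    "train_data": "shards/train-{000000..003999}.tar",  # illustrative pattern for ~4000 shards
    "num_batches_per_epoch": 5000,   # optimizer steps per epoch
    "batch_size": 32,
    "num_workers": 8,
    "shuffle_buffer_size": 1000,
}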
Here’s what I want to achieve:
- Continuous sampling across epochs, wrapping around once all the data has been used.
- Proper reshuffling at both the shard and sample level on every epoch.
- Different data per GPU, i.e. DistributedSampler-like behavior.
- And most importantly: a guarantee that training long enough will eventually sample all of the available data, not just a subset of it.
Here’s the dataset definition I’m currently using:
import webdataset as wds

train_dataset = (
    wds.WebDataset(
        config["train_data"],
        resampled=True,                # sample shards with replacement -> effectively infinite stream
        shardshuffle=True,             # shuffle the order of shards
        handler=wds.handlers.warn_and_continue,
        nodesplitter=wds.shardlists.split_by_node,  # split shards across distributed ranks
    )
    .shuffle(config["shuffle_buffer_size"])      # sample-level shuffle buffer
    .decode(wds.torch_audio)                     # each .flac becomes a (waveform, sample_rate) tuple
    .map(lambda sample: sample["flac.flac"][0])  # keep only the waveform tensor
    .with_epoch(
        # nominal epoch length *per DataLoader worker*, in samples
        config["num_batches_per_epoch"]
        * config["batch_size"]
        // config["num_workers"]
    )
)
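As a quick smoke test (not part of the training script itself) I iterate the pipeline directly and check that it yields waveform tensors; this just assumes the shapes are whatever torchaudio returns for the .flac files:

from itertools import islice

for i, waveform in enumerate(islice(train_dataset, 3)):
    # Expect a torchaudio-style waveform tensor of shape (channels, num_samples).
    print(i, type(waveform), getattr(waveform, "shape", None))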
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    worker_init_fn=seed_worker,             # per-worker RNG seeding (sketched below)
    pin_memory=(device.type == "cuda"),
)
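seed_worker isn’t shown above; for reference, it is the usual per-worker seeding recipe along these lines (a minimal sketch):

import random

import numpy as np
import torch

def seed_worker(worker_id):
    # Derive a per-worker seed from the base seed PyTorch assigns to this worker,
    # and use it to seed NumPy and Python's random module as well.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)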
Is this the right way to do it?
Thanks in advance!