Data missing with torchdata + ddp?

I’m trying to test my model with torchdata and DDP (2 GPUs, sometimes 6), but the final number of samples I get is always less than the actual number.

My data directory has 17 packed tar files; files 0~15 have 5120 entries each, while the last one has only 2730. If I add sharding_filter right after FileLister, it seems that all of the data in the last tar file is ignored by the dataloader. I also tried adding sharding_filter after the map function, but several samples are still missing.

Is there any way to make sure I get the correct number of samples? Do I have to repack my dataset so that all tar files have the same number of entries?
Many thanks in advance!

    import torchdata
    from torch.utils.data import DataLoader
    from torchdata.datapipes.iter import FileLister, FileOpener

    # get_dist_info(), postprocess_func, dist, length and the DataLoader settings
    # are defined elsewhere in my script
    rank, world_size = get_dist_info()
    rootdir = "/data/test/"
    dataset = FileLister(rootdir, "*.tar")

    # if dist: dataset = dataset.sharding_filter()   # sharding at the file level
    dataset = FileOpener(dataset, mode="rb")
    dataset = dataset.load_from_tar(length=length)
    dataset = dataset.webdataset().map(postprocess_func)
    if dist: dataset = dataset.sharding_filter()      # sharding at the sample level

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
        drop_last=False,
        **kwargs)

    # count how many samples this process actually sees
    cnts = 0
    for ind, x in enumerate(data_loader):
        cnts += len(x)
    # cnts does not match the expected total here

For distributed training, the data needs to be balanced across the distributed processes. When you shard the data at the file level (with 2 GPUs), one rank will get more data from the last tar file than the other rank. Technically speaking, I would expect your script to hang at the end of the epoch because of the unbalanced data.
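To make the imbalance concrete, here is a back-of-the-envelope sketch. It assumes a simple round-robin assignment of the 17 tar files to 2 ranks; the exact assignment in your torchdata version may differ, but the conclusion is the same:

    # Illustrative arithmetic only: round-robin sharding of 17 tar files over 2 ranks
    file_sizes = [5120] * 16 + [2730]   # sizes reported in the question
    world_size = 2

    per_rank = [
        sum(n for i, n in enumerate(file_sizes) if i % world_size == r)
        for r in range(world_size)
    ]
    print(per_rank)  # [43690, 40960] -> one rank ends up with 2730 extra samples

Whichever rank receives the last tar file has 2730 more samples to iterate than the other, so the two ranks fall out of step at the end of the epoch, which is where the hang (or the missing samples) comes from.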

It also depends on your batch_size. The total amount of data is 5120 * 16 + 2730 = 84650, so the total number of batches is 84650 / batch_size. If that number is not evenly divisible by the number of GPUs, a few samples will be dropped, and your distributed processes should hang as well.
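The same kind of arithmetic applies to the 6-GPU case you mention: 84650 samples cannot be split evenly across 6 ranks, so with sample-level sharding some ranks see one more sample than others. This is just the division, nothing specific to torchdata:

    # Worked example for the 6-GPU case: sample-level sharding of 84650 samples
    total = 5120 * 16 + 2730        # 84650 samples across all 17 tar files
    world_size = 6
    base, remainder = divmod(total, world_size)
    print(base, remainder)          # 14108 per rank, with 2 ranks getting one extra sample

With 2 GPUs the split itself happens to be even (42325 per rank), so there the effect depends on the batch arithmetic described above.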

BTW, to prevent the hanging problem, we do provide a DataPipe operation called fullsync in this PR: Implement FullSyncIterDataPipe by ejguan · Pull Request #713 · pytorch/data · GitHub. You can attach it at the end of your pipeline (see the sketch after the install commands below). To get it, you need the torchdata nightly release, installed via either of the following options:

  • pip3 install --pre torch torchdata --extra-index-url https://download.pytorch.org/whl/nightly/cpu
  • conda install pytorch torchdata -c pytorch-nightly
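Here is a minimal sketch of where it would go, assuming the same pipeline as in your snippet; fullsync is only available in the nightly builds, and postprocess_func is your own function from the question:

    # Sketch only: same pipeline as in the question, with fullsync attached at the end
    from torchdata.datapipes.iter import FileLister, FileOpener

    dataset = FileLister("/data/test/", "*.tar")
    dataset = FileOpener(dataset, mode="rb")
    dataset = dataset.load_from_tar()
    dataset = dataset.webdataset().map(postprocess_func)
    dataset = dataset.sharding_filter()
    dataset = dataset.fullsync()   # keeps all ranks in lock-step at the end of the epoch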