Issue with DistributedSampler and Custom Batch Sampler in PyTorch Lightning

Hi everyone,

I’m working on a PyTorch Lightning pipeline (on a machine with 4 H100 GPUs) where I need to pass a sorted audio dataset into a custom batch sampler that (1) buckets the samples and (2) batches them. The goal is to minimize padding during data loading: sorting brings samples of similar length together, so each batch needs less padding in the collate_fn, which currently zero-pads all samples to the maximum sequence length.
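
To make the padding argument concrete, here is a minimal sketch in plain Python. The lengths are made up and `padding_per_batching` is a hypothetical helper, not part of my pipeline; it assumes each batch is padded to its own longest item.

```python
import random


def padding_per_batching(lengths, batch_size):
    """Total padding when every batch is padded to its own longest item."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        longest = max(batch)
        total += sum(longest - n for n in batch)
    return total


rng = random.Random(0)
lengths = [rng.randint(1_000, 50_000) for _ in range(64)]  # made-up audio lengths

unsorted_padding = padding_per_batching(lengths, batch_size=8)
sorted_padding = padding_per_batching(sorted(lengths, reverse=True), batch_size=8)
print("padding without sorting:", unsorted_padding)
print("padding with sorting:   ", sorted_padding)
```

Cutting batches from the sorted order never needs more padding than cutting them from the original order, which is why the bucketing depends on the sampler preserving that order.
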
Below is my get_dataloader code, which is called from the training code:

if len(dataset_list) == 1:
    dataset = dataset_list[0]
else:
    dataset = torch.utils.data.ConcatDataset(dataset_list)

# Sort dataset by audio length in decreasing order
sorted_indices = sorted(range(len(dataset)), key=lambda i: dataset[i].audio.shape[1], reverse=True)
sorted_dataset = [dataset[i] for i in sorted_indices]

dataset = sorted_dataset

# Use DistributedSampler without shuffling
sampler = DistributedSampler(dataset, shuffle=False)

# Custom BucketBatchSampler with bucketing logic
bucket_sampler = BucketBatchSampler(
    sampler, 
    batch_size=self.config.physical_batch_size,
    drop_last=train,
    buckets=25,
)

# Create DataLoader with the custom batch sampler
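dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_sampler=bucket_sampler,
    collate_fn=dataset_list[0].collate_fn,
    num_workers=4,
    pin_memory=True,
    # shuffle is omitted: DataLoader raises a ValueError when shuffle=True
    # is combined with batch_sampler; the batch sampler does the shuffling.
)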

The issue I’m encountering is that when I initialize DistributedSampler manually (with use_distributed_sampler=False set in the Trainer initialization), the error below occurs. If I let Lightning handle the distributed setup, without passing shuffle=False myself, no error occurs, but Lightning shuffles the data, which disrupts my sorted indices; my bucketing logic needs the sampler to output sorted indices.

Interestingly, if I use sampler = SequentialSampler(dataset), Lightning replaces it with the DistributedSampler again…

Here’s the error I receive when using DistributedSampler(dataset, shuffle=False):

ValueError: Default process group has not been initialized, please make sure to call init_process_group.
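
For context, DistributedSampler only asks torch.distributed for the world size and rank when num_replicas and rank are not passed in. A minimal sketch, assuming 4 replicas and a stand-in list as the dataset, shows that with both values given explicitly the sampler can be built without any process group:

```python
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(10))  # stand-in for the sorted dataset (assumption)

# num_replicas and rank are given explicitly, so the sampler does not need
# to query torch.distributed and no process group has to exist yet.
samplers = [
    DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    for rank in range(4)
]
shards = [list(sampler) for sampler in samplers]
for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard}")
```

With shuffle=False each rank receives every 4th index of the original order, so the order is not shuffled. Because 10 is not divisible by 4, the sampler pads by repeating indices from the start, so the last entries of some shards wrap around to the beginning of the dataset.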

Has anyone faced a similar challenge, or does anyone know how to prevent the dataset from being shuffled while using DistributedSampler in this context? I’m specifically looking for a way to maintain the sorted order for my custom batch sampling logic.

Also providing my bucket_sampler logic for more context:

import random

from torch.utils.data import BatchSampler


class BucketBatchSampler(BatchSampler):
    def __init__(self, sampler, batch_size, drop_last=False, buckets=25, lengths_dict=None):
        # lengths_dict is accepted but currently unused
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.buckets = buckets

    def __iter__(self):
        # Get indices from the sampler (this will be different for each GPU)
        # print("SAMPLER NAME within bucket batch", self.sampler)
        indices = list(self.sampler)
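        print("sorted_indices", indices)

        # Bucket size; max(1, ...) avoids a zero step in range() below
        # when there are fewer indices than buckets.
        if self.drop_last or len(indices) % self.buckets == 0:
            bucket_size = max(1, len(indices) // self.buckets)
        else:
            bucket_size = (len(indices) // self.buckets) + 1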

        # Create buckets
        buckets = [
            indices[i:i + bucket_size]
            for i in range(0, len(indices), bucket_size)
        ]
        print("bucket", buckets)

        # Shuffle buckets
        random.shuffle(buckets)
        
        for bucket in buckets:
            # Shuffle samples within the bucket
            random.shuffle(bucket)
            
            # Yield batches from the bucket
            for i in range(0, len(bucket), self.batch_size):
                batch = bucket[i:i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    print("length of each batch should be 8", len(batch))
                    print("batch post bucketing", batch)
                    yield batch

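    def __len__(self):
        # Batches are cut per bucket in __iter__, so count them per bucket too;
        # dividing the sampler length by batch_size can give a different number.
        n = len(self.sampler)
        if self.drop_last or n % self.buckets == 0:
            bucket_size = max(1, n // self.buckets)
        else:
            bucket_size = (n // self.buckets) + 1
        total = 0
        for start in range(0, n, bucket_size):
            size = min(bucket_size, n - start)
            if self.drop_last:
                total += size // self.batch_size
            else:
                total += (size + self.batch_size - 1) // self.batch_size
        return total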
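
The bucketing step can also be sketched as a standalone function, which is easier to check in isolation. This is a simplified sketch, not the class above: `bucket_batches`, `seed`, and `epoch` are hypothetical names, it always uses ceiling division for the bucket size, and it draws from a seeded `random.Random` instead of the global `random` module so that the same seed and epoch always give the same batches.

```python
import random


def bucket_batches(indices, batch_size, num_buckets, drop_last=False, seed=0, epoch=0):
    """Split ordered indices into buckets, shuffle them, and cut batches per bucket."""
    rng = random.Random(seed + epoch)  # same seed and epoch -> same order
    bucket_size = max(1, -(-len(indices) // num_buckets))  # ceiling division
    buckets = [
        indices[i:i + bucket_size]
        for i in range(0, len(indices), bucket_size)
    ]
    rng.shuffle(buckets)  # shuffle the order of the buckets

    batches = []
    for bucket in buckets:
        bucket = list(bucket)
        rng.shuffle(bucket)  # shuffle samples inside the bucket
        for i in range(0, len(bucket), batch_size):
            batch = bucket[i:i + batch_size]
            if len(batch) == batch_size or not drop_last:
                batches.append(batch)
    return batches


batches = bucket_batches(list(range(20)), batch_size=4, num_buckets=3)
for batch in batches:
    print(batch)
```

Every batch is drawn from a single bucket, so with length-sorted input the samples in a batch stay close in length even though bucket order and in-bucket order are shuffled.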

This discussion has a good answer: “Error while using custom DistributedSampler” (Lightning-AI/pytorch-lightning, Discussion #7573 on GitHub), but I still don’t understand where to initialize the process group.

I understand that I need to follow the rules mentioned here (as suggested by @ptrblck in one of the discussions), but what values should I use for these parameters?

os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

dist.init_process_group("nccl", rank=rank, world_size=world_size)
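
As for where rank and world_size could come from: launchers such as torchrun export RANK and WORLD_SIZE as environment variables for every process. A minimal sketch, assuming those variables are present (whether Lightning’s own launcher sets the same ones is an assumption to verify), with a hypothetical helper name:

```python
import os


def rank_and_world_size(env=None):
    """Read rank and world size from launcher-style environment variables."""
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))              # defaults to single-process values
    world_size = int(env.get("WORLD_SIZE", 1))
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside world size {world_size}")
    return rank, world_size


rank, world_size = rank_and_world_size()
print(rank, world_size)
```

Inside Lightning hooks such as train_dataloader(), the Trainer also exposes global_rank and world_size, which may be a simpler source for these values than the environment.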

Thanks in advance!