How to train on only specific indices with DistributedSampler?

Hi all,
I am trying to use DistributedSampler and DataLoader with my custom dataset and model, but I am having trouble restricting the DataLoader to specific indices.
I want to train my model with random shuffling over only a specific subset of indices from the whole dataset.

For example, here is my code:

from torch.utils.data.distributed import DistributedSampler

train_dataset = MyDataset(args.data_dir, split='xxx', input_size=input_size)
train_sampler = DistributedSampler(train_dataset, num_replicas=args.gpus, rank=gpu)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=int(args.batch_size / args.gpus),
    shuffle=False,
    sampler=train_sampler,
    collate_fn=train_dataset.collate_fn,
    num_workers=int(args.workers / args.gpus),
    pin_memory=True,
)

I think I need to modify the sampler part, but I have no idea how to make it use only specific indices.
Is there a way to implement what I want there?

P.S.
When I used DataParallel, I could use specific indices like below:

specific_indices = [10, 99, 5, 8, …]
train_dataset = MyDataset(root=xxx, transform=xxx)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    sampler=SubsetRandomSampler(specific_indices),
    collate_fn=detection_collate,
    pin_memory=True,
)

Thank you in advance.

I also had the same issue with WeightedRandomSampler, and I found this reply from @ptrblck to be the best one.
Based on the DistributedSampler source code, the sampler would change to something like this:

import torch
from torch.utils.data import Sampler


class DistributedSubset(Sampler):
    """
    Shuffles a fixed list of indices and splits them across processes.
    Based on https://discuss.pytorch.org/t/how-to-use-my-own-sampler-when-i-already-use-distributedsampler/62143/8
    """

    def __init__(self, indices, num_replicas, rank):
        # num_replicas is the total number of processes (it's common to
        # call this the world size); rank identifies this process within it
        self.indices = indices
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

    def __iter__(self):
        # deterministically shuffle based on the epoch, so every process
        # computes the same permutation
        g = torch.Generator()
        g.manual_seed(self.epoch)
        shuffled = [self.indices[i] for i in torch.randperm(len(self.indices), generator=g)]
        # each process keeps every num_replicas-th index, offset by its rank
        return iter(shuffled[self.rank::self.num_replicas])

    def __len__(self):
        # exact number of indices this process yields; note that ranks can
        # differ by one when len(indices) is not divisible by num_replicas
        # (DistributedSampler pads the index list to avoid this)
        return len(range(self.rank, len(self.indices), self.num_replicas))

    def set_epoch(self, epoch):
        self.epoch = epoch

Just make sure you call set_epoch() at the start of every epoch, so the shuffling changes between epochs.
Let me know if it was helpful!
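
For reference, here is a minimal usage sketch showing where set_epoch() fits. It reuses the names from your post (MyDataset, args, gpu), which are placeholders here, and assumes the two extra constructor arguments of the class above; args.epochs is also an assumed attribute:

specific_indices = [10, 99, 5, 8]  # hypothetical subset to train on

train_dataset = MyDataset(args.data_dir, split='xxx', input_size=input_size)
train_sampler = DistributedSubset(specific_indices, num_replicas=args.gpus, rank=gpu)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=int(args.batch_size / args.gpus),
    shuffle=False,              # the sampler already shuffles
    sampler=train_sampler,
    num_workers=int(args.workers / args.gpus),
    pin_memory=True,
)

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)  # reseed the shuffle each epoch
    for batch in train_loader:
        ...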

Thank you for your answer!
I think it’s a little different, but I’ll try it.