Setting a dataset property while iterating a DataLoader with multiple workers

Problem definition:
I have a dataset with an associated dataloader which I use in a distributed fashion like below:

import torch
from torchvision import datasets

train_dataset = datasets.ImageFolder(traindir, transform=custom_transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

Let’s say each image in the dataset is identified by an index (not a class label): index n refers to the nth image in the dataset.
I have a model that takes an image batch (first_img_batch) and returns some indices. These are the indices of the images in the dataset that are most similar to the input batch according to some similarity metric. Now I want the dataset (or the dataloader) to give me back the images (second_img_batch) associated with these indices.

First workaround:

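for epoch in range(args.start_epoch, args.stop_epoch):
    train_sampler.set_epoch(epoch)
    for first_img_batch, _ in train_loader:  # ImageFolder yields (images, labels)
        indices = model(first_img_batch)
        # now we want to get back the images associated with these indices.
        # ImageFolder only accepts a single int index, so fetch them one at a time:
        second_img_batch = torch.stack(
            [train_loader.dataset[i][0] for i in indices.flatten().tolist()])
        # do something with second_img_batch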

Concerns:
1- I have multiple workers.
2- I am in DDP mode.
@ptrblck any ideas?

I am not sure whether this is thread-safe or safe across GPUs. It also puts a huge load on my CPU cores.
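On the thread-safety concern: `train_loader.dataset[i]` runs in the training process of that rank, while the DataLoader workers operate on their own copies of the dataset, and every DDP rank is a separate process with its own dataset object. A `__getitem__` that does not modify the dataset (ImageFolder's only looks up the file path, loads the image, and applies the transform) therefore does not race with anything. The lookups do run one after another, though. The sketch below spreads them over a thread pool using a hypothetical helper `gather_threaded`, with a `TensorDataset` standing in for the image dataset. Whether it helps depends on how much of the per-image cost releases the GIL (file reads and, in many cases, Pillow's decoding), so treat it as something to measure rather than a guaranteed speedup; it does not reduce the total CPU work.

```python
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.utils.data import TensorDataset, default_collate


def gather_threaded(dataset, indices, max_workers=4):
    """Hypothetical helper: load dataset[i] for each index on a thread pool."""
    if torch.is_tensor(indices):
        indices = indices.flatten().tolist()  # map-style datasets expect Python ints
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map returns results in the same order as `indices`
        samples = list(pool.map(dataset.__getitem__, indices))
    # stack the (image, label) tuples into one batched (images, labels) pair
    return default_collate(samples)


# toy stand-in for the image dataset: "image" n is just the number n * 10
dataset = TensorDataset(torch.arange(20) * 10, torch.arange(20))
imgs, labels = gather_threaded(dataset, [4, 0, 9])
print(imgs.tolist(), labels.tolist())  # [40, 0, 90] [4, 0, 9]
```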

Second workaround:

I tried creating a custom sampler class like this:

from typing import Iterator, List

import torch
import torch.utils.data.distributed


class Second_Sampler(torch.utils.data.distributed.DistributedSampler):
    def __init__(self, dataset, num_positives=4):
        self.num_positives = num_positives
        # DistributedSampler needs an initialized process group (DDP mode)
        super().__init__(dataset=dataset, shuffle=False)
        self.indices: List[int] = []  # filled in by set_sample()

    def __iter__(self) -> Iterator[int]:
        # yield exactly the indices that were set last
        return iter(self.get_sample())

    def __len__(self) -> int:
        return len(self.indices)

    def set_sample(self, indices):
        if torch.is_tensor(indices):
            indices = indices.flatten().tolist()
        self.indices = indices

    def get_sample(self):
        return self.indices


second_sampler = Second_Sampler(train_dataset)
second_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=(second_sampler is None),
    sampler=second_sampler,
    pin_memory=True,
    num_workers=0,  # please note here
    batch_size=args.batch_size,
)
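Since `__iter__` is overridden, the sampler does not use any of the rank-splitting logic it inherits, and `DistributedSampler(dataset)` requires an initialized process group (or explicit `num_replicas` and `rank`). A plain `Sampler` subclass is enough, which also makes the mechanism easy to try outside DDP. The sketch below uses a hypothetical `SettableSampler` and a toy `TensorDataset`. It works because every `iter(loader)` call asks the sampler for a fresh iterator, and for map-style datasets the sampler is iterated in the main process.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SettableSampler(Sampler):
    """Hypothetical plain-Sampler version of Second_Sampler (no DDP needed)."""

    def __init__(self):
        self.indices = []

    def set_sample(self, indices):
        if torch.is_tensor(indices):
            indices = indices.flatten().tolist()
        self.indices = list(indices)

    def __iter__(self):
        # iterate over a copy so a later set_sample() can't change a running iterator
        return iter(list(self.indices))

    def __len__(self):
        return len(self.indices)


dataset = TensorDataset(torch.arange(100) * 10)
sampler = SettableSampler()
loader = DataLoader(dataset, sampler=sampler, batch_size=4, num_workers=0)

sampler.set_sample(torch.tensor([5, 0, 3, 9]))
(batch,) = next(iter(loader))  # each iter() asks the sampler for fresh indices
print(batch.tolist())  # [50, 0, 30, 90]
```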

Now the code becomes:

for epoch in range(args.start_epoch, args.stop_epoch):
    train_sampler.set_epoch(epoch)
    for first_img_batch, _ in train_loader:
        indices = model(first_img_batch)
        # point the second sampler at these indices, then load them as one batch
        second_sampler.set_sample(indices)
        second_img_batch, _ = next(iter(second_loader))
        # do something with second_img_batch

This workaround is also very time-consuming.
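A likely part of the cost: with `num_workers=0`, every second batch is loaded serially in the training process, and with `num_workers>0` (default settings) each `iter(second_loader)` starts a fresh set of worker processes. One option to try is `persistent_workers=True`: the DataLoader then keeps its worker processes and reuses its iterator, and each new `iter()` call re-reads the sampler. This is a sketch under assumptions, not a measured fix: the sampler and the `fetch` helper are hypothetical, the dataset is a toy `TensorDataset`, and `multiprocessing_context="fork"` assumes a Linux/macOS machine.

```python
import torch
from torch.utils.data import DataLoader, Sampler, TensorDataset


class SettableSampler(Sampler):
    """Hypothetical sampler: yields the indices passed to set_sample() last."""

    def __init__(self):
        self.indices = []

    def set_sample(self, indices):
        if torch.is_tensor(indices):
            indices = indices.flatten().tolist()
        self.indices = list(indices)

    def __iter__(self):
        return iter(list(self.indices))

    def __len__(self):
        return len(self.indices)


dataset = TensorDataset(torch.arange(100) * 10)
second_sampler = SettableSampler()
second_loader = DataLoader(
    dataset,
    sampler=second_sampler,
    batch_size=4,
    num_workers=2,
    persistent_workers=True,         # keep the worker processes alive between iter() calls
    multiprocessing_context="fork",  # assumption: POSIX; lets this run as a top-level script
)


def fetch(indices):
    """Hypothetical helper: point the sampler at `indices`, load them as one batch."""
    second_sampler.set_sample(indices)
    (batch,) = next(iter(second_loader))  # reuses the same worker processes
    return batch


print(fetch([3, 1, 4, 1]).tolist())               # [30, 10, 40, 10]
print(fetch(torch.tensor([9, 2, 6, 5])).tolist())  # [90, 20, 60, 50]
```

The sketch asks for exactly `batch_size` indices per call; if you pass more, `next()` only returns the first batch.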

Please help me and get me out of agony. :exploding_head:

Thanks for following up on this issue.
I don’t think the general workflow changes in the distributed setup; I think it’s limited by the way multiprocessing works.
The issue I’m seeing is that you need to manipulate the indices passed to Dataset.__getitem__ in each iteration.
The usual workflow would be to create the Dataset (with a custom sampler), set up the DataLoader, and iterate it for a complete epoch. If you are using multiprocessing, each worker will create a copy of the dataset, so manipulating the underlying loader.dataset through the DataLoader won’t be reflected in the workers until the next epoch.
I’m not 100% sure about the behavior of the sampler, but would also assume it’s copied using num_workers>0.
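The copy behaviour comes from plain multiprocessing and can be reproduced without a DataLoader. In this sketch a child process stands in for a DataLoader worker and a small hypothetical `Settings` object stands in for the dataset: the parent changes `settings.num` after the child has started, and the child still reports the old value. The `"fork"` start method is an assumption (a POSIX system).

```python
import multiprocessing as mp


class Settings:
    """Hypothetical stand-in for a Dataset with a property we change mid-epoch."""

    def __init__(self, num):
        self.num = num


def child(settings, started, changed, out):
    started.set()          # tell the parent this process is running
    changed.wait()         # wait until the parent has modified its object
    out.put(settings.num)  # report the value this process sees


def value_seen_by_child():
    ctx = mp.get_context("fork")  # assumption: POSIX system where 'fork' exists
    settings = Settings(num=8)
    started, changed, out = ctx.Event(), ctx.Event(), ctx.Queue()
    p = ctx.Process(target=child, args=(settings, started, changed, out))
    p.start()
    started.wait()
    settings.num = 12      # only the parent's copy changes
    changed.set()
    seen = out.get()       # read before join() to avoid blocking on the queue
    p.join()
    return seen


if __name__ == "__main__":
    print(value_seen_by_child())  # 8: the child still has its own copy
```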

Maybe a valid approach would be to use a shared array (e.g. via mp.Array as seen in this example), use it in the custom sampler, and apply the changes on the fly. However, now the distributed setup might be critical and you might need to use a locking mechanism to synchronize the manipulation of this array.
I hope someone else has a better idea how this use case can be properly implemented.
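A minimal sketch of the shared-array idea with the standard library's `multiprocessing.Array`: the array is created before the child process starts, the parent writes new indices into it under the array's own lock, and the child reads them back. In a real setup the array would live in the dataset or sampler and be created before the DataLoader starts its workers; that wiring, the `"fork"` start method, and the fixed array length are assumptions of this sketch. Two caveats remain: batches the workers have already prefetched were built from the old values, and each DDP rank would hold its own array (which fits this use case, since every rank computes its own indices).

```python
import multiprocessing as mp


def child(shared_indices, started, changed, out):
    started.set()
    changed.wait()
    with shared_indices.get_lock():  # read a consistent snapshot
        out.put(list(shared_indices))


def indices_seen_by_child(new_indices):
    ctx = mp.get_context("fork")  # assumption: POSIX system where 'fork' exists
    # fixed-size shared buffer of C longs, created *before* the child starts
    shared_indices = ctx.Array("l", len(new_indices))
    started, changed, out = ctx.Event(), ctx.Event(), ctx.Queue()
    p = ctx.Process(target=child, args=(shared_indices, started, changed, out))
    p.start()
    started.wait()
    with shared_indices.get_lock():  # write under the array's lock
        shared_indices[:] = new_indices
    changed.set()
    seen = out.get()
    p.join()
    return seen


if __name__ == "__main__":
    print(indices_seen_by_child([5, 1, 7, 3]))  # [5, 1, 7, 3]: the child sees the update
```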

@ptrblck
Thanks for the reply, I will investigate that further.
I created a dummy example similar to the one you posted before:

import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.ones(10)
        self.num = 10

    def __getitem__(self, index):
        self.data[index] = self.get_sample()
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def set_sample(self, num):
        self.num = num

    def get_sample(self):
        return self.num

num_workers = 10
dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=num_workers,
    shuffle=False
)
for idx, data in enumerate(loader):
    if idx == 3:
        loader.dataset.set_sample(12)
    print(f'{idx}  {data}, num_workers: {num_workers}')

print('-----------------------------------------------')

for idx, data in enumerate(loader):
    if idx == 3:
        loader.dataset.set_sample(12)
    print(f'{idx}  {data}, num_workers: {num_workers}')


num_workers =1
dataset = MyDataset()
loader = DataLoader(
    dataset,
    num_workers=num_workers,
    shuffle=False
)

print('-----------------------------------------------')
for idx, data in enumerate(loader):
    if idx == 3:
        loader.dataset.set_sample(12)
    print(f'{idx}  {data}, num_workers {num_workers}')
print('-----------------------------------------------')

for idx, data in enumerate(loader):
    if idx == 3:
        loader.dataset.set_sample(12)
    print(f'{idx}  {data}, num_workers {num_workers}')


I expected data[3] to data[9] to be 12, but that didn’t happen.
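For comparison, the same dummy dataset with `num_workers=0` behaves the way one would expect from a single process: `__getitem__` runs in the main process, so the change is visible from the next item on. Item 3 was already loaded when the loop body called `set_sample(12)`, so the switch shows up at item 4. With worker processes, each worker works on its own copy of the dataset, so the change only reaches them once new workers are created from the updated object, i.e. in the next epoch. The `run_epoch` helper is hypothetical and only for illustration.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.ones(10)
        self.num = 10

    def __getitem__(self, index):
        self.data[index] = self.num
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def set_sample(self, num):
        self.num = num


def run_epoch(loader):
    seen = []
    for idx, data in enumerate(loader):
        if idx == 3:
            loader.dataset.set_sample(12)  # takes effect for the next __getitem__ call
        seen.append(data.item())
    return seen


loader = DataLoader(MyDataset(), num_workers=0, shuffle=False)
print(run_epoch(loader))  # [10.0, 10.0, 10.0, 10.0, 12.0, 12.0, 12.0, 12.0, 12.0, 12.0]
```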

Just for future reference, in case anyone wants to use it:
the following approach works fine with multiple workers too.


import numpy as np
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    def __init__(self, num_workers, num):
        self.num = num
        self.num_workers = num_workers
        self.nb_samples = 20
        self.data = np.ones(self.nb_samples)

    def __getitem__(self, index):
        self.data[index] = self.num
        return self.data[index]

    def __len__(self):
        return self.nb_samples

    def set_sample(self, num):
        self.num = num


num_workers = 12
dataset = MyDataset(num_workers, 8)
loader = DataLoader(dataset, num_workers=num_workers, shuffle=False)
idx = 0
while idx < len(loader):
    if idx == 3:
        dataset.set_sample(12)
    data = next(iter(loader))
    print(f'{idx}  {data}, num_workers: {num_workers}')
    idx += 1

0  tensor([8.]), num_workers: 12
1  tensor([8.]), num_workers: 12
2  tensor([8.]), num_workers: 12
3  tensor([12.]), num_workers: 12
4  tensor([12.]), num_workers: 12
5  tensor([12.]), num_workers: 12
6  tensor([12.]), num_workers: 12
7  tensor([12.]), num_workers: 12
8  tensor([12.]), num_workers: 12
9  tensor([12.]), num_workers: 12
10  tensor([12.]), num_workers: 12
11  tensor([12.]), num_workers: 12
12  tensor([12.]), num_workers: 12
13  tensor([12.]), num_workers: 12
14  tensor([12.]), num_workers: 12
15  tensor([12.]), num_workers: 12
16  tensor([12.]), num_workers: 12
17  tensor([12.]), num_workers: 12
18  tensor([12.]), num_workers: 12
19  tensor([12.]), num_workers: 12

In your example you are recreating the iterator in every step, so next() returns the same (first) data sample each time inside the loop. I don’t think it’s the right approach.
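The difference is easy to see with a toy `TensorDataset`: a fresh `iter(loader)` always starts from the beginning of the sampler, so `next(iter(loader))` keeps returning index 0 (with `shuffle=False`), whereas advancing a single iterator walks through the dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.arange(5)), num_workers=0, shuffle=False)

# a fresh iterator every time: always starts over at index 0
fresh = [next(iter(loader))[0].item() for _ in range(3)]

# one iterator, advanced step by step: walks through the dataset
it = iter(loader)
stepped = [next(it)[0].item() for _ in range(3)]

print(fresh, stepped)  # [0, 0, 0] [0, 1, 2]
```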

Well, I only meant to show that this property-changing method works.

This is the complete solution I came up with for my main question:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.utils.data.distributed
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import numpy as np
import argparse




class MyDataset(Dataset):
    def __init__(self):
        self.nb_samples = 300
        self.data = np.array([x + .2 for x in range(0, self.nb_samples)])
        self.label = np.array([x + .1 for x in range(0, self.nb_samples)])
        self.index = np.array([x  for x in range(0, self.nb_samples)])

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        return torch.tensor(self.data[index]), torch.tensor(self.label[index]), torch.tensor(self.index[index])

    def __len__(self):
        return self.nb_samples

    def get_sample(self, index):
        print(f'INDEXs: {index}')
        if torch.is_tensor(index):
            index = index.flatten().tolist()
        print(f'INDEX: {index}')
        return torch.tensor(self.data[index]), torch.tensor(self.label[index]), torch.tensor(self.index[index])


def main(device, args):
    ngpus_per_node = torch.cuda.device_count()
    args.world_size = torch.cuda.device_count() * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
  
    args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    torch.cuda.set_device(args.gpu)
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

    dataset = MyDataset()
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = DataLoader(dataset,
                        num_workers=args.workers,
                        shuffle=(sampler is None),
                        sampler=sampler,
                        batch_size = args.batch_size,
                        pin_memory=True)


    for epoch in range(0, 1):
        sampler.set_epoch(epoch)
        for data, label, ind in loader:
            data = data.cuda(args.gpu)
            label = label.cuda(args.gpu)
            ind = ind.cuda(args.gpu)

            print(f'D: {data},  {data.device}')
            print(f'I: {ind} {ind.device}',)
            print(f'L: {label} {label.device}')

            # take a random index to retrieve from the dataset

            rand_ind = torch.randint(len(dataset), (2,)).cuda(args.gpu)
            print('###############################')
            print(f'Random Index to Retrieve: {rand_ind}')
            print('###############################')
            # runs in this training process, not in a DataLoader worker
            datas, labels, inds = loader.dataset.get_sample(rand_ind)

            print(f'DS: {datas},  {datas.device}')
            print(f'IS: {inds} {inds.device}')
            print(f'LS: {labels} {labels.device}')
            print('========================================')


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--world-size', default=1, type=int,
                        help='number of nodes for distributed training')
    parser.add_argument('--rank', default=0, type=int,
                        help='node rank for distributed training')
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:50000', type=str,
                        help='url used to set up distributed training')
    parser.add_argument('--dist-backend', default='nccl', type=str,
                        help='distributed backend')
    parser.add_argument('--gpu', default=None, type=int,
                        help='GPU id to use.')
    parser.add_argument('-b', '--batch-size', default=16, type=int,
                        metavar='N',
                        help='mini-batch size (default: 16), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')

    args, unparsed = parser.parse_known_args()

    arg = vars(args)
    for item in arg:
        print('{:.<30s}{:<50s}'.format(str(item), str(arg[item])))

    print('-------------------------------------------------------------------------------------------------')
    main(device=args.gpu, args=args)
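To check the retrieval path on its own, without GPUs or a process group, the dataset can be exercised directly. The sketch below is a trimmed copy of `MyDataset` (debug prints removed, arrays built with `np.arange`). As in the full script, `get_sample` runs in whichever process calls it, here the main process, and it accepts a tensor of indices because it turns it into a flat Python list for NumPy fancy indexing.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    # same data layout as the solution above, without the debug prints
    def __init__(self, nb_samples=300):
        self.nb_samples = nb_samples
        self.data = np.arange(nb_samples) + 0.2
        self.label = np.arange(nb_samples) + 0.1
        self.index = np.arange(nb_samples)

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        return (torch.tensor(self.data[index]), torch.tensor(self.label[index]),
                torch.tensor(self.index[index]))

    def __len__(self):
        return self.nb_samples

    def get_sample(self, index):
        # a tensor of indices (possibly on the GPU) becomes a flat Python list,
        # which NumPy accepts for fancy indexing
        if torch.is_tensor(index):
            index = index.flatten().tolist()
        return (torch.tensor(self.data[index]), torch.tensor(self.label[index]),
                torch.tensor(self.index[index]))


dataset = MyDataset()
datas, labels, inds = dataset.get_sample(torch.tensor([4, 250]))
print(inds.tolist())  # [4, 250]
```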