Problem definition:
I have a dataset with an associated dataloader which I use in a distributed fashion like below:
import torch
from torchvision import datasets

train_dataset = datasets.ImageFolder(traindir, transform=custom_transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
Let’s say each image in the dataset is associated with an index (not a class label): index n corresponds to the nth image in the dataset.
I have a model that takes an image batch (first_img_batch) and returns some indices. These are the indices of the images in the dataset that are most similar to the input batch according to some similarity metric. Now I want the dataset (dataloader) to give me back the images (second_img_batch) associated with these indices.
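To make the index/image association concrete, here is a minimal sketch of what I mean (IndexedImageFolder is just an illustrative name, not something in my real code):

from torchvision import datasets

class IndexedImageFolder(datasets.ImageFolder):
    # Sketch: every sample also carries its own dataset index,
    # so index n always corresponds to the nth image.
    def __getitem__(self, index):
        img, target = super().__getitem__(index)
        return img, target, index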
First workaround:
for epoch in range(args.start_epoch, args.stop_epoch):
    train_sampler.set_epoch(epoch)
    for first_img_batch, _ in train_loader:  # ImageFolder yields (images, targets)
        indices = model(first_img_batch)
        # Now we want to get the images associated with these indices back.
        # How do we do that? ImageFolder.__getitem__ only accepts a single
        # integer, so fetch one sample at a time and stack the images
        # (assuming indices is a 1-D LongTensor; [0] drops the target):
        second_img_batch = torch.stack(
            [train_loader.dataset[i][0] for i in indices.tolist()])
        # do something with second_img_batch
Concerns:
1- I have multiple workers.
2- I am in DDP mode.
@ptrblck any ideas?
I am not sure whether this is thread-safe or cross-GPU safe. It also imposes a huge load on my CPU cores.
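For context, what I effectively need on each rank boils down to a helper like this (a sketch on my side; I am assuming default_collate is exposed as torch.utils.data.default_collate, which is the case in recent PyTorch versions, while older ones keep it under torch.utils.data.dataloader):

import torch
from torch.utils.data import default_collate

def fetch_by_indices(dataset, indices):
    # Plain read-only __getitem__ calls in the main process of each rank;
    # no DataLoader workers or cross-rank state are involved.
    if torch.is_tensor(indices):
        indices = indices.tolist()
    return default_collate([dataset[i] for i in indices])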
Second workaround:
I tried to create a sampler class like this:

from typing import Iterator

import torch

class Second_Sampler(torch.utils.data.distributed.DistributedSampler):
    def __init__(self, dataset, num_positives=4):
        self.num_positives = num_positives
        self.dataset = dataset
        super().__init__(dataset=self.dataset)
        self.num_samples = self.num_positives
        self.shuffle = False
        self.indices = []  # populated later through set_sample()

    def __iter__(self) -> Iterator[int]:
        indices = self.get_sample()
        return iter(indices)

    def set_sample(self, indices):
        if torch.is_tensor(indices):
            indices = indices.tolist()
        self.indices = indices

    def get_sample(self):
        return self.indices

and use it to build a second loader:

second_sampler = Second_Sampler(train_dataset)
second_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=(second_sampler is None),
    sampler=second_sampler,
    pin_memory=True,
    num_workers=0,  # please note here
    batch_size=args.batch_size,
)
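As a quick sanity check of the set/get round trip (illustrative indices, not values from a real run):

import torch

second_sampler.set_sample(torch.tensor([3, 17, 42, 8]))
assert list(iter(second_sampler)) == [3, 17, 42, 8]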
Now the code becomes:
for epoch in range(args.start_epoch, args.stop_epoch):
    train_sampler.set_epoch(epoch)
    for first_img_batch, _ in train_loader:
        indices = model(first_img_batch)
        # Now we want to get the images associated with these indices back.
        second_sampler.set_sample(indices)
        second_img_batch, _ = next(iter(second_loader))
        # do something with second_img_batch
This workaround is also very time-consuming, presumably because next(iter(second_loader)) builds a fresh loader iterator for every training batch, and with num_workers=0 all the image loading happens sequentially in the main process.
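To put a rough number on it, I time the second fetch like this (an illustrative measurement, not a rigorous benchmark):

import time

start = time.perf_counter()
second_sampler.set_sample(indices)
second_img_batch, _ = next(iter(second_loader))
print(f"second fetch took {time.perf_counter() - start:.3f}s")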
Please help me and get me out of my agony.