I have implemented a batch sampler for single-GPU training. I wrote it because I want all samples in a batch to come from the same source.
import numpy as np
import torch
from torch.utils.data import Dataset


class MultiSourceBatchSampler(torch.utils.data.Sampler):
    def __init__(
        self,
        data_source: Dataset,
        batch_size: int,
        shuffle: bool = False,
        drop_last: bool = False,
    ):
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self._length = -1  # cached batch count, computed lazily in __len__

    def __iter__(self):
        # indices are tuples (source, second_key)
        indices = self.data_source._get_indices()
        if self.shuffle:
            np.random.shuffle(indices)
        indices.sort(key=lambda x: x[0])  # stable sort by source keeps the shuffle within each source
        batched_indices = []
        buffer = []
        for index in indices:
            # start a new batch when the current one is full or the source changes
            if (len(buffer) == self.batch_size) or (len(buffer) > 0 and buffer[0][0] != index[0]):
                batched_indices.append(buffer)
                buffer = []
            buffer.append(index)
        if len(buffer) > 0 and not self.drop_last:
            batched_indices.append(buffer)
        if self.shuffle:
            np.random.shuffle(batched_indices)
        yield from batched_indices

    def __len__(self):
        if self._length >= 0:
            return self._length
        total = 0
        # count batches over the same (source, second_key) tuples as __iter__
        indices = self.data_source._get_indices()
        indices.sort(key=lambda x: x[0])  # sort by source
        buffer = []
        for index in indices:
            if (len(buffer) == self.batch_size) or (len(buffer) > 0 and buffer[0][0] != index[0]):
                total += 1
                buffer = []
            buffer.append(index)
        if len(buffer) > 0 and not self.drop_last:
            total += 1
        self._length = total
        return total
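For context, here is roughly how I use it on a single GPU, with a toy dataset standing in for my real one. The toy dataset only needs to provide _get_indices() returning (source, second_key) tuples and to accept those tuples in __getitem__:

from torch.utils.data import DataLoader, Dataset


class ToyMultiSourceDataset(Dataset):
    """Stand-in for my real dataset: two sources, indexed by (source, key) tuples."""

    def __init__(self):
        self.data = {("A", i): float(i) for i in range(5)}
        self.data.update({("B", i): float(10 + i) for i in range(7)})

    def _get_indices(self):
        # list of (source, second_key) tuples, as the sampler expects
        return list(self.data.keys())

    def __getitem__(self, key):
        return self.data[key]

    def __len__(self):
        return len(self.data)


dataset = ToyMultiSourceDataset()
sampler = MultiSourceBatchSampler(dataset, batch_size=3, shuffle=True)
loader = DataLoader(dataset, batch_sampler=sampler)

for batch in loader:
    print(batch)  # every element in each batch comes from a single source

Note that when batch_sampler is passed, DataLoader's own batch_size, shuffle, sampler, and drop_last arguments must be left at their defaults.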
I could not find very detailed documentation about DistributedSampler, so I wonder: what is the principle I should follow to convert my sampler into a distributed one?
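From reading the source, my current understanding is that DistributedSampler mainly does three things: every rank builds the same (seeded) ordering, each rank keeps only every num_replicas-th element starting at its rank, and set_epoch() changes the seed each epoch so the shuffle differs across epochs but stays identical across ranks. Below is a rough sketch of how I imagine applying that at the batch level; num_replicas, rank, epoch, and set_epoch() mirror the names DistributedSampler uses, while the batch-level sharding and padding are my guess. Is this the right principle?

import math

import numpy as np
import torch
import torch.distributed as dist


class DistributedMultiSourceBatchSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, batch_size, shuffle=False,
                 num_replicas=None, rank=None, seed=0):
        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_replicas = num_replicas if num_replicas is not None else dist.get_world_size()
        self.rank = rank if rank is not None else dist.get_rank()
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch):
        # called once per epoch by the training loop (like DistributedSampler),
        # so every rank reshuffles identically each epoch
        self.epoch = epoch

    def _all_batches(self):
        # every rank must build the SAME list of batches, so shuffle with a
        # shared, epoch-dependent seed instead of the global RNG
        rng = np.random.RandomState(self.seed + self.epoch)
        indices = self.data_source._get_indices()
        if self.shuffle:
            rng.shuffle(indices)
        indices.sort(key=lambda x: x[0])  # group by source, keep shuffle within a source
        batches, buffer = [], []
        for index in indices:
            if len(buffer) == self.batch_size or (buffer and buffer[0][0] != index[0]):
                batches.append(buffer)
                buffer = []
            buffer.append(index)
        if buffer:
            batches.append(buffer)
        if self.shuffle:
            rng.shuffle(batches)
        # pad by repeating batches so the count divides evenly across ranks
        target = math.ceil(len(batches) / self.num_replicas) * self.num_replicas
        batches += batches[: target - len(batches)]
        return batches

    def __iter__(self):
        # each rank keeps every num_replicas-th batch, starting at its own rank
        yield from self._all_batches()[self.rank :: self.num_replicas]

    def __len__(self):
        return len(self._all_batches()) // self.num_replicas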