Hi all!
I have implemented a random triplet sampler as follows:
class TripletSampler(Sampler):
    """Yield batches of random ``[reference, positive, negative]`` index triplets.

    Every index of ``dataset`` is used once per pass as the reference. For
    each reference, a positive index is drawn uniformly among the samples
    sharing its label (the reference itself is a valid candidate, so the
    positive may coincide with it) and a negative index is drawn uniformly
    among the samples with a different label.

    Args:
        dataset: Sized object exposing a ``labels`` attribute with one label
            per sample. Any sequence accepted by ``np.asarray`` works.
        batch_size: Number of triplets per yielded batch.
        drop_last: If true, a trailing batch with fewer than ``batch_size``
            triplets is discarded.
        shuffle: If true, references are visited in a fresh random order on
            every pass; otherwise in sequential order.

    Raises:
        ValueError: While iterating, if some reference has no sample with a
            different label (raised by ``np.random.choice`` on an empty
            candidate set).
    """

    def __init__(self, dataset, batch_size=1, drop_last=True, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        # Batches produced by the most recent pass that ran to completion.
        self.triplets = None
        self.sampler = SequentialSampler(self.dataset)

    def __iter__(self):
        # Coerce once so the element-wise comparisons below also work when
        # the dataset stores its labels in a plain Python list.
        labels = np.asarray(self.dataset.labels)

        if self.shuffle:
            order = np.random.permutation(len(self.sampler)).tolist()
        else:
            order = self.sampler

        batch = []
        triplets = []
        for r_idx in order:
            same_label = labels == labels[r_idx]
            p_idx = int(np.random.choice(np.flatnonzero(same_label)))
            n_idx = int(np.random.choice(np.flatnonzero(~same_label)))
            batch.append([r_idx, p_idx, n_idx])
            if len(batch) == self.batch_size:
                yield batch
                triplets.append(batch)
                batch = []
                # Explicit collection after each full batch; it has no effect
                # on the triplets produced.
                gc.collect()

        if batch and not self.drop_last:
            yield batch
            triplets.append(batch)

        # Only reached when the consumer exhausts the iterator.
        self.triplets = triplets

    def __len__(self):
        """Return the number of batches produced by one full pass."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        # Ceiling division: the partial trailing batch counts as one batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size
While I am training the model, the GPU memory usage slowly increases until there is not enough memory left to continue running. If I do not use the sampler and instead pre-calculate the triplets and feed them to the dataset, it works fine.
I tried adding a gc.collect()
call, but the memory usage still increases, albeit at a really slow rate. Any idea why this is happening?
If needed, this is the implementation of the __getitem__
method inside the dataset (I hope it is self-explanatory):
def __getitem__(self, index):
    """Load the three images of one triplet together with their labels.

    Args:
        index: Iterable of exactly three dataset positions, in the order
            (reference, positive, negative).

    Returns:
        A 6-tuple ``(R, P, N, R_l, P_l, N_l)``: the three images, passed
        through ``self.transforms`` when it is set, followed by their labels.
    """
    # Unpacking enforces that exactly three positions were supplied.
    ref_pos, pos_pos, neg_pos = index
    positions = (ref_pos, pos_pos, neg_pos)

    # Resolve every file path first, then read the EXR images in triplet order.
    paths = [self.data_path[pos] for pos in positions]
    images = [data_utils.read_exr(path) for path in paths]

    # Labels are looked up by the same positions as the images.
    labels = [self.labels[pos] for pos in positions]

    # The optional transform is applied to each image independently.
    if self.transforms is not None:
        images = [self.transforms(image) for image in images]

    return (*images, *labels)