Triplet sampler and memory leaks

Hi all!

I have implemented a random triplet sampler as follows:

class TripletSampler(Sampler):
    """Batch sampler yielding random (anchor, positive, negative) index triplets.

    Every index produced by the anchor sampler becomes an anchor. For each
    anchor, a positive is drawn uniformly from the indices sharing its label
    (the anchor itself included) and a negative uniformly from the indices
    with a different label. Triplets ``[anchor, positive, negative]`` are
    grouped into lists of ``batch_size`` and yielded one batch at a time.

    Args:
        dataset: Object supporting ``len()`` and exposing a ``labels``
            sequence/array. ``labels[i]`` is assumed to be the label of
            sample ``i`` (NOTE(review): confirm alignment with the dataset).
        batch_size: Number of triplets per yielded batch.
        drop_last: If True, a trailing batch smaller than ``batch_size``
            is not yielded.
        shuffle: If True, anchors are visited in random order
            (``RandomSampler``); otherwise in index order.

    Attributes:
        triplets: The batches yielded during the most recently completed
            pass, or ``None`` before the first full pass.
    """

    def __init__(self, dataset, batch_size=1, drop_last=True, shuffle=False):
        # Local import: RandomSampler is only needed when shuffling.
        from torch.utils.data import RandomSampler

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.triplets = None
        self.sampler = (
            RandomSampler(self.dataset) if shuffle else SequentialSampler(self.dataset)
        )

    def __iter__(self):
        """Yield batches of ``[anchor, positive, negative]`` index lists.

        Raises:
            ValueError: If the dataset is non-empty but has fewer than two
                distinct labels, so no negative can be drawn.
        """
        # asarray so element-wise comparison also works when labels is a list.
        labels = np.asarray(self.dataset.labels)
        n_samples = len(labels)

        # Group indices by label once per epoch instead of scanning all labels
        # for every anchor. After a stable argsort by label code, label k
        # occupies the contiguous run order[starts[k] : starts[k] + counts[k]].
        uniq, codes, counts = np.unique(labels, return_inverse=True, return_counts=True)
        codes = codes.ravel()
        order = np.argsort(codes, kind="stable")
        starts = np.concatenate(([0], np.cumsum(counts)[:-1])).astype(np.int64)

        if n_samples > 0 and len(uniq) < 2:
            raise ValueError(
                "TripletSampler needs at least two distinct labels to draw negatives"
            )

        batch = []
        triplets = []
        for r_idx in self.sampler:
            k = codes[r_idx]
            start, size = starts[k], counts[k]

            # Positive: uniform over the anchor's own label group.
            p_idx = order[start + np.random.randint(size)]

            # Negative: uniform over every position outside that group. Draw
            # from the n - size remaining slots and skip over the group.
            j = np.random.randint(n_samples - size)
            if j >= start:
                j += size
            n_idx = order[j]

            batch.append([r_idx, p_idx, n_idx])
            if len(batch) == self.batch_size:
                yield batch
                triplets.append(batch)
                batch = []
            # NOTE: no per-batch gc.collect() here. It only reclaims
            # unreachable Python objects and cannot release memory held by
            # tensors that are still referenced. Steadily growing GPU memory
            # is more likely caused by something retained in the training
            # loop (review suggestion, not verified from this code).

        if batch and not self.drop_last:
            yield batch
            triplets.append(batch)

        # Only small lists of indices are kept; rebound on every full pass.
        self.triplets = triplets

    def __len__(self):
        """Return the number of batches yielded per pass."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size

While training the model, GPU memory usage slowly increases until there is not enough memory to continue. If I skip the sampler and instead pre-compute the triplets and feed them to the dataset, it works fine.
I tried adding some gc.collect() calls, but memory still keeps increasing, albeit at a really slow rate. Any idea why this is happening?

In case it helps, here is the implementation of the dataset's __getitem__ method (hopefully self-explanatory):

def __getitem__(self, index):
    """Load one triplet of images together with their labels.

    Args:
        index: A 3-element iterable ``(anchor, positive, negative)`` of
            dataset indices, e.g. one entry produced by a triplet batch
            sampler.

    Returns:
        A 6-tuple ``(R, P, N, R_l, P_l, N_l)``: the three images (passed
        through ``self.transforms`` when it is set) followed by their labels.
    """
    anchor, positive, negative = index
    members = (anchor, positive, negative)

    # Phase 1: resolve file paths for all three members.
    paths = [self.data_path[i] for i in members]

    # Phase 2: read the images, in anchor/positive/negative order.
    images = [data_utils.read_exr(path) for path in paths]

    # Phase 3: look up the labels.
    targets = [self.labels[i] for i in members]

    # Phase 4: apply the same transform pipeline to each image separately.
    if self.transforms is not None:
        images = [self.transforms(img) for img in images]

    return (*images, *targets)