Custom Sampler in Pytorch

Hi, I was trying to implement a custom sampler.
But I get a memory out of memory error on my GPU system.
Here is the code.

from torch.utils.data.sampler import Sampler
class SSGDSampler(Sampler):
    r"""Samples elements according to SSGD Sampler
    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, data_source, model, batch_size):
        self.training_data = data_source.train_data.to(device)
        self.training_label = data_source.train_labels.to(device)
        self.training_data = self.training_data.view(self.training_data.shape[0], 1, self.training_data.shape[1], self.training_data.shape[2])
        self.training_data = self.training_data.type(torch.cuda.FloatTensor)
        self.model = model
        self.batch_size = batch_size


    def compute_score(self):
        sampled=[]
        print(self.training_data.shape)
        output = model(self.training_data)
        loss = F.cross_entropy(output, self.training_label, reduce=False)
        prob = F.softmax(loss)
        feat = model.feat
        for _ in range(0, self.batch_size):
            if len(sampled)==0:
                 sampled.extend(torch.argmax(prob))
            else:
                dist = torch.mm(self.feat, self.feat[sampled].T)
                min_dist = torch.min(dist, dim=0)
                mean_dist = torch.mean(dist, dim=0)
                score = min_dist + mean_dist + prob
                max_idx = torch.argmax(score)
                sampled.extend(max_idx)
        return sampled

    def __iter__(self):
        sampled=self.compute_score()
        print(sampled)
        return iter(sampled)

    def __len__(self):
        return len(self.data_source)

One solution would be to sample in batches. But are there any better solutions. Also any better solutions in general for the sampler i want to create since this is 30 min hacking on pytorch.

Resolved here. https://github.com/pytorch/pytorch/issues/7359

1 Like