Hi, I was trying to implement a custom sampler.
But I get a memory out of memory error on my GPU system.
Here is the code.
from torch.utils.data.sampler import Sampler
class SSGDSampler(Sampler):
r"""Samples elements according to SSGD Sampler
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source, model, batch_size):
self.training_data = data_source.train_data.to(device)
self.training_label = data_source.train_labels.to(device)
self.training_data = self.training_data.view(self.training_data.shape[0], 1, self.training_data.shape[1], self.training_data.shape[2])
self.training_data = self.training_data.type(torch.cuda.FloatTensor)
self.model = model
self.batch_size = batch_size
def compute_score(self):
sampled=[]
print(self.training_data.shape)
output = model(self.training_data)
loss = F.cross_entropy(output, self.training_label, reduce=False)
prob = F.softmax(loss)
feat = model.feat
for _ in range(0, self.batch_size):
if len(sampled)==0:
sampled.extend(torch.argmax(prob))
else:
dist = torch.mm(self.feat, self.feat[sampled].T)
min_dist = torch.min(dist, dim=0)
mean_dist = torch.mean(dist, dim=0)
score = min_dist + mean_dist + prob
max_idx = torch.argmax(score)
sampled.extend(max_idx)
return sampled
def __iter__(self):
sampled=self.compute_score()
print(sampled)
return iter(sampled)
def __len__(self):
return len(self.data_source)
One solution would be to sample in batches. But are there any better solutions. Also any better solutions in general for the sampler i want to create since this is 30 min hacking on pytorch.