Memory leak in dataloader?

I’m trying to train a CNN using a custom dataset, and I’ve noticed that CPU memory usage balloons (even though the data and model are on a GPU). The code has been stripped down to the point that the only things left are a for loop over the epochs and an inner loop that gets samples from the dataloader.

Version of Pytorch: 1.1.0
TorchVision: 0.3.0

Even calling the garbage collector doesn’t help much. The only thing that does seem to help is iterating through all the objects tracked by the GC. I’m not sure this is really a solution, but the memory leak seems much less severe when I add that to the code. I originally did this to instrument the code and check whether more tensors were being created over time.

def train_model(model, datasets, criterion, optimizer, scheduler, device, num_epochs=25, numIterStop = 7):
	"""Run a train pass and a validation pass per epoch, with early stopping.

	Args:
		model: network to optimise; put in train() mode for the 'train' phase
			and eval() mode for the 'val' phase.
		datasets: dict with datasets['loader'][phase] (iterable yielding
			(inputs, labels) batches) and datasets['size'][phase] (the divisor
			used for the epoch loss/accuracy), for phase in ('train', 'val').
		criterion: loss callable, invoked as criterion(outputs, labels).
		optimizer: optimizer whose gradients are zeroed per batch and which is
			stepped during the 'train' phase.
		scheduler: learning-rate scheduler.
			NOTE(review): it is never stepped in this function - confirm
			whether that is intentional.
		device: device each batch is moved to before the forward pass.
		num_epochs: maximum number of epochs to run.
		numIterStop: forwarded to TrainingMonitor.

	Returns:
		The TrainingMonitor that recorded the per-phase loss and accuracy.
		Returned early as soon as evaluateEpoch() reports a falsy value after
		a validation pass.
	"""
	setLabels = ['train','val']
	trainingMonitor = TrainingMonitor(model, numIterStop = numIterStop)
	for epoch in range(num_epochs):
		print('Epoch {}/{}'.format(epoch, num_epochs - 1))
		print('-' * 10)
		for setLabel in setLabels:
			isTraining = setLabel == 'train'
			if isTraining:
				model.train()  # Set model to training mode
			else:
				model.eval()   # Set model to evaluate mode

			running_loss = 0.0
			running_corrects = 0

			# Iterate over data. Rebinding inputs/labels/loss on every
			# iteration releases the previous batch, so no explicit `del`
			# is needed.
			for inputs, labels in datasets['loader'][setLabel]:
				inputs = inputs.to(device)
				labels = labels.to(device)

				# zero the parameter gradients
				optimizer.zero_grad()

				# forward
				# track history if only in train
				with torch.set_grad_enabled(isTraining):
					outputs = model(inputs)
					_, preds = torch.max(outputs, 1)
					loss = criterion(outputs, labels)
					# backward + optimize only in the training phase
					if isTraining:
						loss.backward()
						optimizer.step()

				# statistics
				# labels.size(0) is the number of samples in this batch.
				# Presumably criterion returns the batch-mean loss, so scaling
				# by the batch size recovers the summed loss - confirm the
				# criterion's reduction.
				batch_size = labels.size(0)
				# loss.item() yields a plain Python float, so no graph is
				# kept alive by the running total.
				running_loss += loss.item() * batch_size
				running_corrects += torch.sum(preds == labels)
			epoch_loss = running_loss / datasets['size'][setLabel]
			epoch_acc = running_corrects.double() / datasets['size'][setLabel]
			trainingMonitor.recordData(setLabel,epoch_loss, epoch_acc)
			print('{} Loss: {:.4f} Acc: {:.4f}'.format(
				setLabel, epoch_loss, epoch_acc))
			# Early stop: the monitor decides after each validation pass
			# whether training should continue.
			if setLabel == 'val':
				if not trainingMonitor.evaluateEpoch(model,epoch_loss,epoch_acc):
					return trainingMonitor

	time_elapsed = trainingMonitor.timeElapsed()
	print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
	print('Best val Acc: {:4f}'.format(trainingMonitor.best_acc))
	# NOTE(review): the best weights are not reloaded into `model` here;
	# presumably the caller restores them from the monitor - confirm.
	return trainingMonitor

Here is how it is loaded

		# Build the dataset once, then give each phase its own shallow copy.
		# copy.copy() is shallow, so both copies share the same underlying
		# attribute objects rather than duplicating the loaded data.
		dataset = ImageDataset() #TextDataset(glove_path, embeddingDim = params['FeatureSize']) #SplitDataset(shuffleData=True) #ImageDataset()
		trainingDataset = copy.copy(dataset)
		valDataset = copy.copy(dataset)
		# NOTE(review): selectPhase() is not called on either copy here, so
		# both loaders read whatever data/labels the dataset held after
		# construction - confirm the phase is selected elsewhere.
		trainingDL = torch.utils.data.DataLoader(trainingDataset, batch_size=params["batch_size"],shuffle=True, num_workers=0) # Work this into your model
		validationDL = torch.utils.data.DataLoader(valDataset, batch_size=params["batch_size"],shuffle=True, num_workers=0)
		# Bundle sizes and loaders under the keys train_model() reads:
		# Dataset['size'][phase] and Dataset['loader'][phase].
		DatasetSizes = {"train": len(trainingDataset),"val": len(valDataset)}
		Dataloaders = {"train": trainingDL,"val":validationDL}
		Dataset = {"size":DatasetSizes,"loader":Dataloaders}

Below is the code for the dataset.

class SplitDataset(MongoDataset):
	"""MongoDataset whose ids are partitioned into training and validation slices.

	The first `trainingProportion` fraction of `self.ids` (rounded down)
	becomes `trainingSet`; the remainder becomes `validationSet`.
	Subclasses are expected to override selectPhase(); the base version
	raises NotImplementedError.
	"""
	def __init__(self, trainingProportion = 0.8, col = "Profile"):
		MongoDataset.__init__(self, col=col)
		# Index of the first id that belongs to the validation slice.
		splitIndex = int(len(self) * trainingProportion)
		self.phase = 'train'
		self.trainingSet, self.validationSet = self.ids[0:splitIndex], self.ids[splitIndex:]

	def selectPhase(self, phase):
		# Switching phases is left to subclasses.
		raise NotImplementedError

class ImageDataset(SplitDataset):
	"""Dataset that flattens each profile's images into (image, label) samples.

	At construction, every profile in the training and validation id slices
	is restored and its images are collected, each paired with that
	profile's `liked` value. selectPhase() then chooses which of the two
	collections __getitem__ reads from.

	NOTE(review): `self.labels` starts empty, so __getitem__ has nothing to
	index until selectPhase() has been called. No __len__ is defined here, so
	the inherited one is used - confirm it matches the number of images.
	"""
	def __init__(self, trainingProportion = 0.8, col = "Profile"):
		SplitDataset.__init__(self,trainingProportion = trainingProportion, col = col)
		self.labels = []
		self.trainingData,self.trainingLabels = self._concatenateImages(self.trainingSet)
		self.valData,self.valLabels = self._concatenateImages(self.validationSet)

	def _concatenateImages(self, ids):
		"""Return (images, labels) gathered from the profiles named by `ids`.

		Each profile contributes all of its images, and its `liked` value is
		repeated once per image so the two lists stay aligned.
		"""
		imageSet = []
		labelSet = []
		# Iterate the ids that were passed in, so each split restores its own
		# profiles instead of the leading entries of self.loadData.
		for profileId in ids:
			profile = self.mongoDB.restore_profile(profileId, col = self.col)
			images = list(profile.images)
			imageSet.extend(images)
			labelSet.extend([profile.liked] * len(images))
		return imageSet, labelSet

	def __getitem__(self, i):
		# NOTE(review): this applies self.imageTransform, while selectPhase()
		# assigns self.transform - confirm which transform is intended.
		return self.imageTransform(self.loadData[i]),int(self.labels[i])

	def selectPhase(self, phase):
		"""Point loadData/labels/transform at the 'train' or 'val' collection.

		Raises:
			ValueError: if `phase` is neither 'train' nor 'val'.
		"""
		if phase == 'val':
			self.loadData = self.valData
			self.labels = self.valLabels
			self.transform = ProfileDataset.validationTransform
		elif phase == 'train':
			self.loadData = self.trainingData
			self.labels = self.trainingLabels
			self.transform = ProfileDataset.trainingTransform
		else:
			raise ValueError("Invalid phase")

I did see a similar bug on GitHub; however, it only seemed to occur when num_workers > 0 (not the case here) and when lists are used for the data (I’m using a list to aggregate all the images and line them up with the correct labels).

Am I missing something obvious? If not, how could I instrument my code better? So far I’ve been using resource monitoring tools and pdb to try and track down the issue.

1 Like