CUDA OOM when using pretrained vectors

I’m training a CNN for text classification. When I try to update 300-d pretrained embeddings, I get an OOM error when backpropagating (specifically, on calling optim.step()), regardless of batch size. However, if the embeddings are frozen or learned from scratch, there are no issues. I suspect this is due to the way they are being loaded. Below is the embedding loading code.

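	import numpy as np
	import torch
	import torch.nn as nn
	# Vocab is the poster's own class; its constructor is called exactly as in the original.

	def _loadEmbeddings(self, vecFilePath):
		print("Loading pretrained embeddings: ", vecFilePath)
		with open(vecFilePath, encoding="utf8") as f:
			readLen = self.vocab.MAX_VOCAB_SIZE
			if readLen == -1:
				lines = f.readlines()  # -1 means "read the whole file"
			else:
				lines = [next(f) for _ in range(readLen)]
			words = [line.split()[0] for line in lines]
			self.vocab = Vocab(MAX_VOCAB_SIZE=self.vocab.MAX_VOCAB_SIZE, MAX_IDX=1, UNK_IDX=1)
			# float32 halves memory compared with NumPy's float64 default
			embeddings = np.zeros((len(self.vocab), self.embedding_dim), dtype=np.float32)  # zeros or rand?
			for line in lines:
				values = line.split()
				word = values[0]
				index = self.vocab(word)
				if index:  # 0 is falsy, so row 0 is never written here
					embeddings[index] = np.array(values[1:], dtype=np.float32)
			# row 1 (UNK) is set to the mean of rows 2 onward
			embeddings[1] = np.mean(embeddings[2:], axis=0)
			# from_pretrained freezes the weight by default (freeze=True);
			# pass freeze=False when the vectors should be fine-tuned
			return nn.Embedding.from_pretrained(torch.from_numpy(embeddings))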
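A quick way to check what a loader like this actually hands to the model is to inspect the resulting weight. The sketch below uses a tiny zero matrix in place of the real vectors (the sizes are illustrative only). It shows that `np.zeros` defaults to float64, that `torch.from_numpy` keeps that dtype, and that `nn.Embedding.from_pretrained` freezes the weight unless `freeze=False` is passed.

```python
import numpy as np
import torch
import torch.nn as nn

# Tiny stand-in for the real 400,000 x 300 matrix (sizes are illustrative only).
vocab_size, dim = 10, 4

default = np.zeros((vocab_size, dim))                    # NumPy default dtype: float64
compact = np.zeros((vocab_size, dim), dtype=np.float32)  # half the bytes per entry

emb64 = nn.Embedding.from_pretrained(torch.from_numpy(default))
emb32 = nn.Embedding.from_pretrained(torch.from_numpy(compact), freeze=False)

print(emb64.weight.dtype, emb64.weight.requires_grad)  # torch.float64 False
print(emb32.weight.dtype, emb32.weight.requires_grad)  # torch.float32 True
print(emb64.weight.element_size(), emb32.weight.element_size())  # 8 4
```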

Does anyone have any ideas as to possible causes or tools that could help debug this issue?
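One way to see where the memory goes is to total the bytes held by the parameters, their gradients, and the optimizer state before and after `optim.step()`. This is a minimal sketch: `tensor_bytes` and `memory_report` are hypothetical helpers, not PyTorch APIs. It assumes Adam purely as an example, because Adam allocates its per-parameter state lazily on the first step, which could line up with the OOM appearing at `optim.step()`. On a GPU, `torch.cuda.memory_allocated()` and `torch.cuda.max_memory_allocated()` report the caching allocator's view of the same thing.

```python
import torch
import torch.nn as nn


def tensor_bytes(tensors):
    """Sum the storage size (in bytes) of an iterable of tensors, skipping None."""
    return sum(t.numel() * t.element_size() for t in tensors if t is not None)


def memory_report(model, optimizer):
    """Hypothetical helper: bytes held by params, grads, and optimizer state."""
    params = list(model.parameters())
    state_tensors = [
        v
        for per_param in optimizer.state.values()
        for v in per_param.values()
        if torch.is_tensor(v)
    ]
    return {
        "params": tensor_bytes(params),
        "grads": tensor_bytes(p.grad for p in params),
        "optimizer_state": tensor_bytes(state_tensors),
    }


# Small trainable embedding standing in for the real one (sizes are illustrative).
emb = nn.Embedding(1000, 50)
opt = torch.optim.Adam(emb.parameters())

loss = emb(torch.tensor([1, 2, 3])).sum()
loss.backward()                 # dense gradient: same shape as the whole weight
before = memory_report(emb, opt)
opt.step()                      # Adam allocates exp_avg and exp_avg_sq here
after = memory_report(emb, opt)

print(before)
print(after)
```

Even though only three rows were looked up, the gradient and both Adam buffers are full-size, so the optimizer state alone is about twice the size of the weight after the first step.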

Are you using the same vocab size and embedding dimensions for both use cases (pretrained and trained from scratch)?
If the dimensions are the same, the code should not create an OOM error in one use case while the other is working fine.

The dimensions are the same; however, the pretrained embedding vocab is ~4x larger (400,000 vs 100,000 words). Does PyTorch update the entire embedding at once? Otherwise I don’t see how a larger embedding could be an issue (aside from consuming more memory, but what I was seeing looked more like an explosion in memory usage).
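To put the question above into numbers: a default `nn.Embedding` (`sparse=False`) receives a gradient with the same shape as its whole weight matrix, and an optimizer like Adam keeps two more buffers of that shape, so with dense gradients every step does touch every row. A back-of-envelope sketch, assuming Adam (the thread doesn't say which optimizer is used) and counting only the embedding itself; `embedding_training_bytes` is a hypothetical helper.

```python
def embedding_training_bytes(vocab_size, dim, bytes_per_value=4, copies=4):
    """Rough lower bound for a trainable dense embedding.

    copies=4 assumes Adam: weight + gradient + exp_avg + exp_avg_sq.
    Activations, CUDA context, and allocator overhead are not included.
    """
    return vocab_size * dim * bytes_per_value * copies


GiB = 1024 ** 3
for vocab in (100_000, 400_000):
    for name, nbytes in (("float32", 4), ("float64", 8)):
        total = embedding_training_bytes(vocab, 300, nbytes)
        print(f"{vocab:>7,} x 300 {name}: {total / GiB:.2f} GiB")
```

For 400,000 x 300 this comes to roughly 1.79 GiB in float32 and 3.58 GiB in float64, versus about 0.45 and 0.89 GiB for 100,000 words, before any activations. `nn.Embedding(..., sparse=True)` produces sparse gradients instead, but only some optimizers (for example `torch.optim.SparseAdam`) accept them.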

Note: I messed around with it some more and suspect it is a bug in the vocab’s code. The issue occurs when the vocab size isn’t explicitly stated (instead of a size, you can pass -1, which acts as if you were getting the last index of your array).
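Using -1 as a "no limit" sentinel is easy to get wrong, because any code path that passes it straight through as a count or slice bound misbehaves quietly: `range(-1)` is empty and `lines[:-1]` drops the last line. A minimal sketch of resolving the sentinel in one place, assuming -1 means "read everything" as described above; `read_vector_lines` is a hypothetical helper, not part of the poster's `Vocab`.

```python
import io
from itertools import islice


def read_vector_lines(f, max_lines=-1):
    """Return the first max_lines lines of f, or every line when max_lines == -1."""
    if max_lines == -1:
        limit = None  # islice(f, None) yields everything
    elif max_lines >= 0:
        limit = max_lines
    else:
        raise ValueError(f"max_lines must be -1 or non-negative, got {max_lines}")
    # islice also stops cleanly if the file has fewer lines than requested
    return list(islice(f, limit))


# The pitfalls of using -1 directly as a size:
print(len(range(-1)))        # 0: nothing would be read
print(["a", "b", "c"][:-1])  # ['a', 'b']: the last entry is silently dropped

sample = io.StringIO("the 0.1 0.2\nof 0.3 0.4\nand 0.5 0.6\n")
print(len(read_vector_lines(sample, -1)))  # 3
```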