Slow CNN training: possible CPU usage during backprop

I’m trying to reimplement Yoon Kim’s 2014 CNN sentence classification paper (https://arxiv.org/abs/1408.5882) and I’ve run into a speed issue: a single epoch currently takes 18 hours (according to tqdm). The output of PyTorch’s profiler is as follows:

----------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------
Name                          Self CPU total %  Self CPU total   CPU total %      CPU total        CPU time avg     CUDA total %     CUDA total       CUDA time avg    Number of Calls  Input Shapes               
----------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------
EmbeddingBackward             6.85%            1.461s           6.85%            1.461s           1.461s           3.96%            3.616ms          3.616ms          1                []                          
embedding_backward            6.85%            1.461s           6.85%            1.461s           1.461s           3.95%            3.612ms          3.612ms          1                []                          
embedding_dense_backward      6.85%            1.460s           6.85%            1.460s           1.460s           3.87%            3.536ms          3.536ms          1                []                          
EmbeddingBackward             6.84%            1.458s           6.84%            1.458s           1.458s           10.68%           9.764ms          9.764ms          1                []                          
embedding_backward            6.84%            1.458s           6.84%            1.458s           1.458s           10.68%           9.760ms          9.760ms          1                []                          
embedding_dense_backward      6.84%            1.458s           6.84%            1.458s           1.458s           10.26%           9.384ms          9.384ms          1                []                          
EmbeddingBackward             6.75%            1.439s           6.75%            1.439s           1.439s           11.12%           10.170ms         10.170ms         1                []                          
embedding_backward            6.75%            1.439s           6.75%            1.439s           1.439s           11.12%           10.166ms         10.166ms         1                []                          
embedding_dense_backward      6.75%            1.439s           6.75%            1.439s           1.439s           10.67%           9.754ms          9.754ms          1                []                          
EmbeddingBackward             6.45%            1.376s           6.45%            1.376s           1.376s           3.84%            3.508ms          3.508ms          1                []                          
embedding_backward            6.45%            1.376s           6.45%            1.376s           1.376s           3.84%            3.508ms          3.508ms          1                []                          
embedding_dense_backward      6.45%            1.376s           6.45%            1.376s           1.376s           3.78%            3.452ms          3.452ms          1                []                          
EmbeddingBackward             6.45%            1.374s           6.45%            1.374s           1.374s           4.10%            3.746ms          3.746ms          1                []                          
embedding_backward            6.45%            1.374s           6.45%            1.374s           1.374s           4.09%            3.740ms          3.740ms          1                []                          
embedding_dense_backward      6.45%            1.374s           6.45%            1.374s           1.374s           4.06%            3.712ms          3.712ms          1                []                          
----------------------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------  -----------------------------------
Self CPU time total: 21.323s
CUDA time total: 91.428ms
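For reference, this is roughly how I’m collecting the profile (a minimal sketch with a toy forward/backward standing in for my actual training step):

```python
import torch
import torch.autograd.profiler as profiler

# Toy graph standing in for one training step
x = torch.randn(8, 4, requires_grad=True)
w = torch.randn(4, 3, requires_grad=True)

with profiler.profile(use_cuda=torch.cuda.is_available()) as prof:
    loss = (x @ w).relu().sum()
    loss.backward()

# Sort by self CPU time to surface ops that run on the host
print(prof.key_averages().table(sort_by="self_cpu_time_total"))
```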

The issue appears to be heavy CPU use during the backward pass (the embedding backward in particular); however, the model is explicitly moved to the GPU in the code. Why might this happen?
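One thing I did verify along the way: calling .cuda() on a layer’s *output* does not move the layer’s own parameters, so a layer created on the CPU keeps its weights (and its backward pass) there. A minimal sketch of that check (toy sizes, not my actual model):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(100, 8)     # parameters live on the CPU by default
idx = torch.randint(0, 100, (4, 6))
out = emb(idx)                 # lookup runs on the CPU
if torch.cuda.is_available():
    out = out.cuda()           # moves only the activation tensor

# The layer's weight tensor was never moved:
print(emb.weight.device)       # -> cpu
```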

Below is the code for the model:

class SentenceClassifierCNN(WordClassifier):
	def __init__(self, vocab, processor, embeddings = None, **args):
		super().__init__(vocab, processor, embeddings, **args)

		self.model_type = args['model_type']
		# Filters are tuples: first element is the filter size (in words), second is the number of filters
		self.filters = args['filters']
		self.max_sent_len = args.get('max_sent_len','random')
		self.dropout = args.get('dropout',0)
		
		if self.model_type == 'multichannel':
			self.embedding2 = nn.Embedding(len(vocab), args['embedding_dim'])
		elif self.model_type == 'non-static':
			self.embedding.weight.requires_grad = True
		self.conv_layers = nn.ModuleList([nn.Conv1d(in_channels=1, out_channels=filter_size[1], kernel_size=filter_size[0]* self.embedding_dim) for filter_size in self.filters])
		self.actFunc = torch.nn.ReLU()
		self.projection_layer = torch.nn.Linear(in_features=sum([filter[1] for filter in self.filters]), out_features=args['output_dim'], bias=True)
		torch.nn.init.xavier_uniform_(self.projection_layer.weight)

	def forward(self, sentences):
		input = self._embed(sentences, max_len=self.max_sent_len).cuda()
		emb = input.view(-1, 1, self.embedding_dim * self.max_sent_len) # flatten embeddings into a single channel
		if self.model_type == 'multichannel':
			emb2 = self._embed(sentences, max_len=self.max_sent_len, embedding=self.embedding2).view(-1, 1, self.embedding_dim * self.max_sent_len)
			emb = torch.cat((emb, emb2), 1)
			
		conv_results = [torch.max(F.relu(conv_layer(emb)), dim=2)[0] for conv_layer in self.conv_layers]
		flattened_convs = torch.cat(conv_results, 1)
		dropped_out_convs = F.dropout(flattened_convs, p=self.dropout, training=self.training)
		return self.projection_layer(dropped_out_convs)