PyTorch LSTM syntax with/without passing hidden state

I have a question regarding LSTMs.

I have an LSTM class that is defined as follows. I’ll call it Version I:

class LSTMTagger(nn.Module):
	def __init__(self, n_tag = 2):
		super(LSTMTagger, self).__init__()
		self.hidden_dim = args.mlp_hidden_dim
		self.word_embeddings = nn.Embedding(VOCAB_SIZE, args.embedding_dim)
		self.lstm = nn.LSTM(input_size = args.embedding_dim, hidden_size = args.rnn_dim, batch_first=True)
		#some linear layers....
		self.hidden = self.init_hidden()


	def init_hidden(self):
		# Before we've done anything, we don't have any hidden state.
		# The axes semantics are (num_layers * num_directions, minibatch_size, hidden_dim)
		return (Variable(torch.zeros(1, 1, self.hidden_dim)),
				Variable(torch.zeros(1, 1, self.hidden_dim)))

	def forward(self, x):
		x = self.word_embeddings(x)
		x = x.unsqueeze(0)
		x, self.hidden = self.lstm(x, self.hidden)
		#... more code that takes x through the linear layers... blah blah blah...
		y = F.softmax(x, dim=1)
		return y 

At the beginning of every iteration (i.e., every new sequence), I call:

model.hidden = model.init_hidden()

Very quickly, I get the following error:

RuntimeError: CuDNN error: CUDNN_STATUS_EXECUTION_FAILED

I tried to define the same LSTM class somewhat differently. I’ll call it Version II:

class LSTMTagger(nn.Module):
	def __init__(self, n_tag = 2):
		super(LSTMTagger, self).__init__()
		self.word_embeddings = nn.Embedding(VOCAB_SIZE, args.embedding_dim)
		self.lstm = nn.LSTM(input_size = args.embedding_dim, hidden_size = args.rnn_dim, batch_first=True)
		#some linear layers....

	def forward(self, x):
		x = self.word_embeddings(x)
		x = x.unsqueeze(0)
		x, _ = self.lstm(x)
		#blah blah blah like before
		y = F.softmax(x, dim=1)
		return y 

The difference is that this time I didn’t initialize the hidden states, and I’m not passing them on to self.lstm(). This version works.
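For reference, here is a minimal sketch of what the two call styles do on a standalone nn.LSTM. When no state is passed, nn.LSTM starts from zero-filled h_0 and c_0, so Version II is equivalent to passing explicit zeros of shape (num_layers * num_directions, batch, hidden_size). The sizes below are hypothetical and chosen only for illustration; the post does not give the values of args.embedding_dim or args.rnn_dim. Note that Version I sizes its state from args.mlp_hidden_dim while the LSTM uses args.rnn_dim; if those two differ, the state would not have the expected shape, which may be related to the error, though the post does not confirm this.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration only; the post does not state them.
embedding_dim, rnn_dim, seq_len = 8, 16, 5

lstm = nn.LSTM(input_size=embedding_dim, hidden_size=rnn_dim, batch_first=True)
x = torch.randn(1, seq_len, embedding_dim)  # (batch=1, seq_len, embedding_dim)

# Version II style: no state is passed, so h_0 and c_0 default to zeros.
out_default, (h_n, c_n) = lstm(x)

# Version I style: an explicit zero state, sized by the LSTM's own hidden_size.
# The state layout stays (num_layers * num_directions, batch, hidden_size)
# even when batch_first=True.
zero_state = (torch.zeros(1, 1, lstm.hidden_size),
              torch.zeros(1, 1, lstm.hidden_size))
out_explicit, _ = lstm(x, zero_state)

same = torch.allclose(out_default, out_explicit)
print(same)
print(tuple(h_n.shape))
```

Because every new sequence here starts from a fresh zero state, omitting the state argument and resetting it by hand before each sequence should behave the same way, provided the hand-built state has the shape shown above.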

My question is: why and how does Version II work? I can’t understand why Version I fails (I guess something is exploding), or how Version II works without receiving the previous hidden state.

Thanks in advance
