I'm using an LSTM, and a runtime error keeps occurring whenever the hidden state size is larger than 16; at a hidden size of exactly 16 there is no error, but the network outputs NaN instead.
The model is defined as:
class LSTM(nn.Module):
    """Sequence model: a (possibly multi-layer) LSTM followed by a linear decoder.

    Inputs to ``forward`` are expected with shape
    ``(seq_length, batch_size, input_dim)``; outputs have shape
    ``(seq_length, batch_size, output_dim)``.

    Args:
        input_dim: size of each flattened input frame (n_row * n_col).
        hidden_dim: LSTM hidden state size.
        batch_size: fixed batch size used to build the default initial states.
        output_dim: size of the decoder output per time step.
        num_layers: number of stacked LSTM layers.
        dropout: inter-layer dropout for nn.LSTM. NOTE: nn.LSTM applies
            dropout only *between* layers, so with num_layers=1 a nonzero
            value has no effect (PyTorch emits a warning).
        h0, c0: optional initial hidden/cell states of shape
            (num_layers, batch_size, hidden_dim); zeros by default.
    """

    def __init__(self, input_dim, hidden_dim, batch_size, output_dim=1,
                 num_layers=1, dropout=0, h0=None, c0=None):
        super(LSTM, self).__init__()
        self.input_dim = input_dim  # n_row * n_col
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.num_layers,
                            dropout=dropout)
        self.decoder = nn.Linear(self.hidden_dim, self.output_dim)
        # Initial states default to zeros. Register them as buffers so that
        # model.to(device) / model.cuda() moves them together with the
        # parameters. The original code only called .cuda() at construction
        # time, so states created on a CPU-only build (or before a later
        # .to(device)) ended up on the wrong device.
        if h0 is None:
            h0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_dim)
            c0 = torch.zeros(self.num_layers, self.batch_size, self.hidden_dim)
        self.register_buffer("h0", h0)
        self.register_buffer("c0", c0)

    def forward(self, input):
        """Run the full sequence through the LSTM and decode every step.

        Args:
            input: tensor of shape (seq_length, batch_size, input_dim).

        Returns:
            (prediction, (h_n, c_n)) where prediction has shape
            (seq_length, batch_size, output_dim).
        """
        lstm_out, self.hidden = self.lstm(input, (self.h0, self.c0))
        prediction = self.decoder(lstm_out)
        return prediction, self.hidden

    def step(self, input, h, c):
        """Propagate a single time step from an explicit hidden state.

        Args:
            input: tensor of shape (1, batch_size, input_dim).
            h, c: hidden/cell states of shape (num_layers, batch_size, hidden_dim).

        Returns:
            (prediction, (h_n, c_n)) where prediction has shape
            (1, batch_size, output_dim).
        """
        # Move inputs to the model's own device instead of calling .cuda()
        # unconditionally: this works on CPU-only machines and respects
        # whichever CUDA device the model actually lives on.
        device = self.decoder.weight.device
        lstm_out, self.hidden = self.lstm(
            input.to(device), (h.to(device), c.to(device)))
        prediction = self.decoder(lstm_out)
        return prediction, self.hidden
and trained with:
# Train: warm up on the first seen_step frames, then free-run the model for
# fut_step steps (zero input, fed-back hidden state) and regress the rollout
# against the ground-truth future frames with MSE.
loss_fn = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr, weight_decay=weight_decay)
for epoch in range(n_epochs):
    loss_total = 0
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx == end_idx:
            break
        # torch.autograd.Variable is deprecated since PyTorch 0.4 — plain
        # tensors participate in autograd, so use the views directly.
        input_seq = batch["input"].view(seq_length, batch_size, -1)
        output_seq = batch["output"].view(seq_length, batch_size, -1)
        if torch.cuda.is_available():
            input_seq = input_seq.cuda()
            output_seq = output_seq.cuda()
        optimizer.zero_grad()
        # Encode the observed prefix to obtain the hidden state (h, c).
        _, (h, c) = model(input_seq[0:seen_step])
        # Roll forward fut_step steps with an all-zero input frame,
        # threading the hidden state through each step.
        empty_input = torch.zeros_like(input_seq[0:1])
        fut_prediction = []
        for t in range(fut_step):
            prediction, (h, c) = model.step(empty_input, h, c)
            fut_prediction.append(prediction)
        pred_seq = torch.cat(fut_prediction, dim=0)
        truth_seq = output_seq[seen_step:]
        loss = loss_fn(pred_seq, truth_seq)
        # .item() already returns a detached Python float; no .detach() needed.
        loss_total += loss.item()
        loss.backward()
        # Gradient clipping guards against exploding gradients (a common
        # cause of NaNs as hidden_dim grows); grad_clip == 0 disables it.
        if grad_clip != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()