Error on loss.backward() with RNN

This might be a dumb question and the solution is probably obvious, but I just can’t find it. I’m trying to train a char-RNN model, but I get a RuntimeError in the loss.backward() call:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
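The message means that the current loss still depends on part of a graph that an earlier backward() call has already freed. A minimal sketch of how that happens, using a hypothetical weight w and a tensor h that plays the role of a hidden state carried from one loop iteration to the next (nothing here comes from the original model):

```python
import torch

w = torch.randn(3, requires_grad=True)  # hypothetical parameter
h = torch.zeros(3)                      # plays the role of a hidden state carried across iterations

errors = []
for step in range(2):
    h = torch.tanh(h * w)  # from step 1 on, h links this graph to the previous one
    loss = h.sum()
    try:
        loss.backward()    # frees this iteration's saved tensors
    except RuntimeError as e:
        errors.append(str(e))

print(len(errors))  # 1: only the second backward() fails
```

The first backward() succeeds; the second one walks from the new graph back through h into the first iteration's graph, whose saved tensors are gone.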

My model looks like this:

import torch
from torch import nn

class CharRNN(nn.Module):
  def __init__(self, vocab_size, embedding_dim=256, rnn_units=512, num_layers=3, dropout=0):
    super().__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    self.rnn = nn.GRU(embedding_dim, rnn_units, num_layers, dropout=dropout, batch_first=True)
    self.decoder = nn.Linear(rnn_units, vocab_size)
  
  def forward(self, inputs, state=None):
    x = self.embedding(inputs)
    output, state = self.rnn(x, state)
    output = self.decoder(output)
    return output, state
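The state returned by forward() is the GRU's final hidden state h_n. A quick sketch with a bare nn.GRU using CharRNN's default sizes (the batch and sequence sizes are hypothetical) shows its shape and that it is still attached to the autograd graph of the forward pass that produced it:

```python
import torch
from torch import nn

# Same GRU configuration as CharRNN's defaults; batch and sequence sizes are hypothetical.
embedding_dim, rnn_units, num_layers = 256, 512, 3
batch_size, seq_len = 4, 10
rnn = nn.GRU(embedding_dim, rnn_units, num_layers, batch_first=True)

x = torch.randn(batch_size, seq_len, embedding_dim)
output, state = rnn(x)

print(output.shape)  # torch.Size([4, 10, 512]): batch_first applies to output
print(state.shape)   # torch.Size([3, 4, 512]): h_n is (num_layers, batch, hidden) regardless of batch_first
print(state.grad_fn is not None)  # True: state is part of this forward pass's autograd graph
```

Because state keeps that graph reference, passing it into the next batch's forward() links the two graphs. It also means every batch that receives a carried-over state must have the same batch size, since nn.GRU checks the hidden state's shape against the input.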

Custom loss class. I don’t know if this is correct; it’s supposed to be CrossEntropyLoss for inputs of shape (batch_size, sequence_length, vocab_size) and targets of shape (batch_size, sequence_length):

class CategoricalCrossEntropyLoss(nn.Module):
  def __init__(self):
    super().__init__()
    self.log_softmax = nn.LogSoftmax(-1)
  
  def forward(self, inputs, targets):
    log_probs = self.log_softmax(inputs)
    log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return -log_probs.mean()
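One way to check the custom loss is to compare it with the built-in nn.CrossEntropyLoss, which accepts inputs of shape (batch, classes, seq_len), so the logits need a transpose(1, 2) first. A sketch with random logits and hypothetical sizes:

```python
import torch
import torch.nn.functional as F
from torch import nn

class CategoricalCrossEntropyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(-1)

    def forward(self, inputs, targets):
        # inputs: (batch, seq_len, vocab_size) logits; targets: (batch, seq_len) class indices
        log_probs = self.log_softmax(inputs)
        log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        return -log_probs.mean()

torch.manual_seed(0)
batch, seq_len, vocab_size = 4, 10, 65  # hypothetical sizes
logits = torch.randn(batch, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch, seq_len))

custom = CategoricalCrossEntropyLoss()(logits, targets)
# nn.CrossEntropyLoss expects the class dimension second: (batch, vocab_size, seq_len)
builtin = nn.CrossEntropyLoss()(logits.transpose(1, 2), targets)
# Equivalent alternative: flatten batch and time into a single dimension
flat = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))

print(torch.allclose(custom, builtin), torch.allclose(custom, flat))  # True True
```

With the default reduction='mean', all three average the negative log-likelihood over every (batch, time) position, so the custom class gives the same value as nn.CrossEntropyLoss applied to output.transpose(1, 2).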

And my training loop looks like this:

import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CharRNN(vocab_size)
model.to(device)
loss_fn = CategoricalCrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.002, alpha=0.95)
scheduler = StepLR(optimizer, 5, 0.95)

epochs = 50
for epoch in range(epochs):
  running_loss = 0.0
  state = None
  model.train(True)
  for data in train_loader:
    inputs, targets = data[0].to(device), data[1].to(device)
    optimizer.zero_grad()
    output, state = model(inputs, state)
    loss = loss_fn(output, targets)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
  scheduler.step()
  running_loss /= len(train_loader)
  
  # gradient clipping
  nn.utils.clip_grad.clip_grad_value_(model.parameters(), 5)

  # validate
  model.train(False)
  running_vloss = 0.0
  for data in val_loader:
    inputs, targets = data[0].to(device), data[1].to(device)
    output, state = model(inputs, state)
    vloss = loss_fn(output, targets)
    running_vloss += vloss.item()
    
  running_vloss /= len(val_loader)

  print(f'Epoch {epoch + 1}/{epochs}: Training Loss = {running_loss}; Validation Loss = {running_vloss};')

I think the problem is that you clip the gradients after the optimizer has already taken its step.
The correct order may be: call loss.backward(), then clip the gradients, then call optimizer.step().

You could try it; it’s just a hypothesis.
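The suggested order as a minimal sketch, using a hypothetical train_step helper and an nn.Linear stand-in for the real model, with the clip value of 5 from the question. The point is only that clip_grad_value_ has to run after backward() has filled each .grad and before step() reads it:

```python
import torch
from torch import nn

def train_step(model, loss_fn, optimizer, inputs, targets, clip_value=5.0):
    """Hypothetical helper: one optimization step with gradient clipping."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()                                            # 1. populate p.grad
    nn.utils.clip_grad_value_(model.parameters(), clip_value)  # 2. clamp p.grad in place
    optimizer.step()                                           # 3. update using the clipped gradients
    return loss.item()

torch.manual_seed(0)
model = nn.Linear(4, 3)  # stand-in for CharRNN
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.002, alpha=0.95)
loss = train_step(model, nn.CrossEntropyLoss(), optimizer,
                  torch.randn(8, 4), torch.randint(0, 3, (8,)))

# step() does not clear .grad, so the clipped gradients are still visible here
max_grad = max(p.grad.abs().max().item() for p in model.parameters())
print(max_grad <= 5.0)  # True
```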

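I had to detach the RNN state at each iteration, so that backward() no longer tries to go through the previous batch’s graph, which has already been freed: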

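if isinstance(state, torch.Tensor):  # GRU / vanilla RNN: a single h_n tensor
    state = state.detach()
elif state is not None:  # LSTM: an (h_n, c_n) tuple; state is None before the first batch
    state = tuple(s.detach() for s in state)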
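The same idea wrapped in a hypothetical detach_state helper, exercised as a sketch on a bare nn.GRU with random data for a few consecutive batches. Calling it at the top of each iteration truncates backpropagation at the batch boundary while still passing the hidden state's values forward:

```python
import torch
from torch import nn

def detach_state(state):
    """Hypothetical helper: cut the hidden state off from the previous batch's graph."""
    if state is None:                        # first batch: nothing to detach
        return None
    if isinstance(state, torch.Tensor):      # GRU / vanilla RNN: single tensor h_n
        return state.detach()
    return tuple(s.detach() for s in state)  # LSTM: (h_n, c_n)

torch.manual_seed(0)
rnn = nn.GRU(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
optimizer = torch.optim.RMSprop(rnn.parameters(), lr=0.002)

state = None
losses = []
for _ in range(3):  # three consecutive "batches" of random data
    x = torch.randn(4, 5, 8)
    state = detach_state(state)  # without this, the second iteration's backward() raises a RuntimeError
    optimizer.zero_grad()
    output, state = rnn(x, state)
    loss = output.pow(2).mean()  # arbitrary loss, just to have something to backpropagate
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(len(losses))  # 3
```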

The gradient clipping was also in the wrong place, though: it should go after loss.backward() and before optimizer.step(), as @Matteo_Ciotola pointed out.
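Putting both fixes together, a self-contained sketch of the training loop on random data. The data, the smaller layer sizes, epochs = 3 and the detach_state helper are hypothetical; it uses nn.CrossEntropyLoss on output.transpose(1, 2) in place of the custom loss class, and it also runs validation under torch.no_grad() so no graph is built there:

```python
import torch
from torch import nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, TensorDataset

class CharRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim=32, rnn_units=64, num_layers=3, dropout=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, rnn_units, num_layers, dropout=dropout, batch_first=True)
        self.decoder = nn.Linear(rnn_units, vocab_size)

    def forward(self, inputs, state=None):
        output, state = self.rnn(self.embedding(inputs), state)
        return self.decoder(output), state

def detach_state(state):
    """Hypothetical helper: detach a GRU tensor or LSTM tuple state; pass None through."""
    if state is None:
        return None
    if isinstance(state, torch.Tensor):
        return state.detach()
    return tuple(s.detach() for s in state)

# Hypothetical random data standing in for the real character dataset.
torch.manual_seed(0)
vocab_size, seq_len, batch_size = 20, 16, 8

def make_loader(n_batches):
    data = torch.randint(0, vocab_size, (n_batches * batch_size, seq_len + 1))
    # Targets are the inputs shifted by one character; all batches have the same size.
    return DataLoader(TensorDataset(data[:, :-1], data[:, 1:]), batch_size=batch_size)

train_loader, val_loader = make_loader(4), make_loader(2)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CharRNN(vocab_size).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.002, alpha=0.95)
scheduler = StepLR(optimizer, 5, 0.95)
epochs = 3  # the question uses 50

for epoch in range(epochs):
    model.train()
    running_loss, state = 0.0, None
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        state = detach_state(state)  # fix 1: don't backprop into the previous batch's graph
        optimizer.zero_grad()
        output, state = model(inputs, state)
        loss = loss_fn(output.transpose(1, 2), targets)  # (batch, vocab_size, seq_len)
        loss.backward()
        nn.utils.clip_grad_value_(model.parameters(), 5)  # fix 2: clip before the step
        optimizer.step()
        running_loss += loss.item()
    scheduler.step()
    running_loss /= len(train_loader)

    model.eval()
    running_vloss = 0.0
    with torch.no_grad():  # no autograd graph is built during validation
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            output, state = model(inputs, state)
            running_vloss += loss_fn(output.transpose(1, 2), targets).item()
    running_vloss /= len(val_loader)

    print(f'Epoch {epoch + 1}/{epochs}: Training Loss = {running_loss:.4f}; Validation Loss = {running_vloss:.4f}')
```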
