LSTM Training: variable required for grad computation was modified by inplace operation

Well, the error is in the title. For some reason there is some operation that breaks the gradient, and i have no idea where it is

Here is the code:

Model: (Audio decoders and encoders are conv1d encoders that compress input 4 times, and are not the source of the problem, they were tested separately and they worked)

class AudioGenerator(nn.Module):
  def __init__(self,bs):
    super(AudioGenerator,self).__init__()

    self.encoder = AudioEncoder()
    self.decoder = AudioDecoder()
    self.bs = bs #batch size

    self.lstm = nn.LSTM(4000,4000, batch_first = True)

    self.linear1 = nn.Linear(4000,4000)

    self.reset_hidden(bs)

    self.tanh = nn.Tanh()
    self.relu = nn.LeakyReLU()

  def forward(self,inp):
    
    out1 = self.encoder(inp)
    _,(self.hidden,self.state) = self.lstm(out1, (self.hidden,self.state))

    out = self.hidden.clone()

    out = torch.squeeze(out,0)

    out = self.linear1(out)
    out = self.tanh(out)

    out = out.reshape(out.shape[0], 1,out.shape[1])
    out = self.decoder(out)
    return torch.squeeze(out,1)

  def reset_hidden(self, bs = None):
    if bs is None:
      self.hidden = torch.zeros(1,self.bs,4000).to(DEVICE)
      self.state = torch.zeros(1,self.bs,4000).to(DEVICE)
    else:
      self.hidden = torch.zeros(1,bs,4000).to(DEVICE)
      self.state = torch.zeros(1,bs,4000).to(DEVICE)

def print_inline(message): #a helper function
  sys.stdout.write("\r" + message)

Training loop: (dataset returns a tensor of shape (seq_length) where seq_length is 16000*num_seconds, data loader assembles them into shape (batch_size,1,seq_length), after that they get separated into seconds of length 16000, and processed)

losses = [[]]
epochs = 100
batchnum = 1
batches = len(train_loader)
try:
  for epoch in range(0,epochs):
    for seqs in train_loader:
      seqs = seqs.to(DEVICE)
      #print("Seqs shape is: {}".format(seqs.shape))
      start = time.time()

      seqs = torch.split(seqs,16000,1) #extract seconds from input sequence
      #Reset hidden and grad
      model.reset_hidden()
      optimizer.zero_grad()

      for i in range(len(seqs)-1): #edited mistake here, was compairing network output ot its input, fixed that
        output = model(torch.unsqueeze(seqs[i],1))
        output_loss = loss(output,seqs[i+1])
        output_loss.backward(retain_graph=True)
        optimizer.step()
        losses[0].append(output_loss.item())

      print_inline("Epoch: {}/{}  Batch: {}/{} Loss: {:.4f}; time/batch:{:.4f} seconds".format(epoch,epochs,batchnum,batches,
                                                                                               output_loss.item(), time.time() - start))   
      if batchnum % 25 == 0:
        print('New checkpoint!')
        torch.save({'optimizer_state_dict':generator.state_dict(), 
                    'model_state_dict':discriminator.state_dict()
                    },save_dir)
        plot_losses(losses)
      batchnum += 1  

    batchnum = 0
    plot_losses(losses)
except KeyboardInterrupt:  #if interrupted show statistics
    plot_losses(losses)

plot_losses(losses)

I think it is weird that your input second is also your target. The best your network can do is to be the identity function.

Thanks for reply, fixed that

1 Like