Backward takes a long time while GPU usage is zero: training one LSTM model with another LSTM model

Hi all,

I'm training an LSTM model with the code below. First I generate a time series with my LSTM model inverse_lstm. The generated series is then fed to another, pre-trained LSTM model test_lstm, which is in evaluation mode. The idea is to train inverse_lstm with the help of test_lstm: the output of inverse_lstm is the input to test_lstm, and the output of test_lstm is used in the loss function.
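
For context, this minimal, self-contained sketch shows the gradient path I am aiming for; the plain nn.LSTM modules and toy shapes below only stand in for my actual models:

import torch
import torch.nn as nn

inverse_lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)  # stand-in for my trainable model
test_lstm = nn.LSTM(input_size=16, hidden_size=16, batch_first=True)    # stand-in for the pre-trained model
test_lstm.eval()
for p in test_lstm.parameters():
    p.requires_grad_(False)  # the pre-trained model stays fixed

x = torch.randn(4, 10, 8)            # toy input
generated, _ = inverse_lstm(x)       # output of the trainable model...
y_hat, _ = test_lstm(generated)      # ...goes straight into the fixed model
loss = nn.functional.mse_loss(y_hat[:, -1, :], torch.zeros(4, 16))
loss.backward()                      # gradients flow through test_lstm into inverse_lstm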

# Generate input with inverse_lstm.
# The loader is not shuffled, so chronological order is preserved.

x_generated = []
for i, data in enumerate(loader):
    x_d, x_s, x_one_hot, y = data
    x_d, x_s = x_d.to(device), x_s.to(device)
    x_one_hot, y = x_one_hot.to(device), y.to(device)
    x_ = inverse_lstm(x_d=x_d, x_s=x_s, x_one_hot=x_one_hot, y_true=y)[0]
    x_ = x_[:, 0, :]    # keep the first step along the sequence dimension
    x_ = x_.cpu().data  # move to CPU; .data returns a tensor detached from the graph
    x_generated.append(x_)

# The generated data will be fed into a fixed LSTM model. 

x_generated = torch.cat(x_generated, dim=0)     # concatenate along the batch dimension
x_generated = x_generated.requires_grad_(True)  # make the concatenated tensor a leaf that requires grad
ds_gen = Dataset_Precip(ds=ds, x_precip=x_generated)
gen_loader = DataLoader(ds_gen, batch_size=128, shuffle=True)
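
Dataset_Precip is my own dataset class; the real code is longer, but a simplified, hypothetical stand-in behaves roughly like this (illustrative only, not the real implementation):

from torch.utils.data import Dataset

class Dataset_Precip_Sketch(Dataset):
    # Hypothetical stand-in: return the original sample with the
    # dynamic input replaced by the matching slice of the generated tensor.
    def __init__(self, ds, x_precip):
        self.ds = ds
        self.x_precip = x_precip

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, i):
        x_d, x_s, x_one_hot, y = self.ds[i]
        # Since x_precip.requires_grad is True, this slice is a non-leaf
        # tensor whose grad_fn points back at the full x_precip tensor.
        return self.x_precip[i], x_s, x_one_hot, y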

# Train inverse_lstm with test_lstm. 
# The gradient needs to trace back to inverse_lstm. 

for data in tqdm(gen_loader):
    x_d, x_s, x_one_hot, y = data
    x_d, x_s = x_d.to(device), x_s.to(device)
    x_one_hot, y = x_one_hot.to(device), y.to(device)
    y_hat = test_lstm(x_d=x_d, x_s=x_s, x_one_hot=x_one_hot)[0]
    y_hat_sub = y_hat[:, -1:, :]  # keep only the last time step
    y_sub = y[:, -1:, :]
    optimizer.zero_grad()
    loss = mse_loss(y_hat_sub, y_sub)
    loss.backward(retain_graph=True)  # keep the graph alive after backward
    optimizer.step()
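
For what it's worth, this is the sanity check I would run once, right after loss.backward(), to see whether gradients actually reach inverse_lstm (just a sketch reusing the variable names from above):

# Do gradients actually arrive where I expect them?
print(x_generated.grad is not None)      # grad on the generated leaf tensor
for name, p in inverse_lstm.named_parameters():
    print(name, p.grad is not None)      # grads on the model I want to train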

The training process takes a long time, especially loss.backward(): around 2 minutes for a single batch. Monitoring CPU and GPU, I see that CPU usage is quite high while GPU usage stays at zero during the backward pass. Is it because the gradient is hard to compute from the loss all the way back to the first LSTM model? Or is there something wrong with my code?
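
In case it helps to narrow this down, the backward pass can be profiled in isolation with something like the following (a rough sketch; it assumes the loss from the loop above and a CUDA device):

import torch

# Profile a single backward call to see which ops dominate
# and whether any of them actually run on the GPU.
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    loss.backward(retain_graph=True)
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))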

Thank you for any suggestions and help!