I’m trying to use an LSTM model to predict the right-hand side of the system of ODEs for a pendulum.
Problem: There are several strange spikes in my learning curves:
It turns out that if I set
drop_last=True in my DataLoader, those spikes disappear.
Dropping the last mini batch seems to be a workaround. Is there a better way to address this problem? Is there a bug in my training loop that causes this?
Also, this appears to only happen when the sequence length is long (this one is 248*5.5=1364). Smaller sequence lengths do not have spikes like that, which also makes me think I have a bug somewhere.
def train_model(datatrain_IP, datatrain_OP, time_data, n_val, batch_size,
                model, nepoch, lrate, reg_factor, reg_type, device_name,
                epoch_in=0, train_loss_history=(), val_loss_history=(),
                drop_last=False):
    """Train ``model`` with an MSE loss and NAdam, validating after each epoch.

    Args:
        datatrain_IP, datatrain_OP, time_data, n_val, batch_size: Forwarded
            unchanged to ``prepare_data`` to build the train and validation
            loaders.
        model: Module to optimise; called as ``model(inputs)``.
        nepoch: Exclusive upper bound of the epoch index.
        lrate: Learning rate passed to ``optim.NAdam``.
        reg_factor, reg_type: Accepted for interface compatibility; not used
            by this function.
        device_name: Device that each mini-batch is moved to.
        epoch_in: First epoch index (use a non-zero value to resume).
            NOTE(review): a fresh optimizer is created on every call, so
            optimizer state is not carried over when resuming.
        train_loss_history: Tuple of previous per-epoch training losses; one
            value is appended per epoch.
        val_loss_history: Tuple of previous per-epoch validation losses; one
            value is appended per epoch.
        drop_last: Forwarded to ``prepare_data``.

    Returns:
        Tuple ``(train_loss_history, val_loss_history, optimizer, epoch,
        loss, actuals, predictions, val_time)``. ``epoch`` is the last epoch
        index run, ``loss`` is the loss tensor of the last training batch and
        the final three entries come from the last ``evaluate_model`` call.
        When ``range(epoch_in, nepoch)`` is empty, these five trailing values
        are ``None`` and the histories are returned unchanged.

    Raises:
        ValueError: If the training loader yields no batches in an epoch.
    """
    loss_func = nn.MSELoss()
    optimizer = optim.NAdam(model.parameters(), lr=lrate)

    # Build the train/validation loaders (split logic lives in prepare_data).
    train_dl, val_dl = prepare_data(datatrain_IP, datatrain_OP, time_data,
                                    n_val, batch_size, drop_last=drop_last)

    # Placeholders so the return statement is well defined even when no
    # epoch is executed.
    epoch = None
    loss = None
    actuals = predictions = val_time = None

    for epoch in range(epoch_in, nepoch):
        model.train()
        weighted_loss_sum = 0.0
        n_elements = 0

        for inputs, targets, _ in train_dl:
            inputs = inputs.to(device_name, non_blocking=False)
            targets = targets.to(device_name, non_blocking=False)

            optimizer.zero_grad()
            yhat = model(inputs)
            loss = loss_func(yhat, targets)
            loss.backward()
            optimizer.step()

            # MSELoss (mean reduction) averages over the elements of the
            # batch, so weight each batch loss by its element count. A short
            # final mini-batch then contributes in proportion to its size
            # instead of counting as much as a full batch.
            # Assumes yhat and targets have the same shape.
            batch_elements = targets.numel()
            weighted_loss_sum += loss.item() * batch_elements
            n_elements += batch_elements

        if n_elements == 0:
            raise ValueError(
                "training data loader produced no batches; check batch_size "
                "and drop_last against the size of the training set"
            )

        # Element-weighted mean MSE over the whole epoch.
        train_loss_history = train_loss_history + (
            weighted_loss_sum / n_elements,
        )

        # Validation pass.
        model.eval()
        mse, actuals, predictions, val_time = evaluate_model(
            val_dl, model, device_name
        )
        val_loss_history = val_loss_history + (mse,)

    # The original returned undefined names (val_data_actual /
    # val_data_predictions); the validation outputs are actuals / predictions.
    return (train_loss_history, val_loss_history, optimizer, epoch, loss,
            actuals, predictions, val_time)