Hi everyone. I’m trying to implement a video classification scheme, and everything seems fine so far except one thing: exploding gradients in the validation loop. I know it sounds strange because there aren’t supposed to be gradients in the validation process, but that’s also what I don’t get. I’ve made sure to turn on eval() mode and use torch.no_grad(), and somehow exploding gradients (with NaN outputs) still happen ONLY when there is a validation loop. I’ve tried commenting out the validation code and the training code ran smoothly, so I figured something must be wrong with the validation code, but I can’t put my finger on it. I’d really appreciate some help to point me in the right direction.
My code:
for epoch in range(params.getint('num_epochs')):
print('Starting epoch %i:' % (epoch + 1))
print('*********Training*********')
training_loss = 0
training_losses = []
training_progress = tqdm(enumerate(train_loader))
artnet.train()
for batch_index, (frames, label) in training_progress:
training_progress.set_description('Batch no. %i: ' % batch_index)
frames = frames.to(device)
label = label.to(device)
optimizer.zero_grad()
output = artnet.forward(frames)
loss = criterion(output, label)
training_loss += loss.item()
loss.backward()
optimizer.step()
else:
avg_loss = training_loss / len(train_loader)
training_losses.append(avg_loss)
print(f'Training loss: {avg_loss}')
print('*********Validating*********')
validating_loss = 0
validating_losses = []
validating_progress = tqdm(enumerate(validation_loader))
artnet.eval()
with torch.no_grad():
for batch_index, (frames, label) in validating_progress:
validating_progress.set_description('Batch no. %i: ' % batch_index)
frames = frames.to(device)
label = label.to(device)
output = artnet.forward(frames)
loss = criterion(output, label)
validating_loss += loss.item()
else:
avg_loss = validating_loss / len(validation_loader)
validating_losses.append(avg_loss)
print(f'Validating loss: {avg_loss}')
print('=============================================')
print('Epoch %i complete' % (epoch + 1))
if (epoch + 1) % params.getint('ckpt') == 0:
print('Saving checkpoint...' )
torch.save(artnet.state_dict(), os.path.join(params['ckpt_path'], 'arnet_%i' % (epoch + 1)))
# Update LR
scheduler.step()
print('Training complete, saving final model....')
torch.save(artnet.state_dict(), os.path.join(params['ckpt_path'], 'arnet_final'))
return training_losses, validating_losses