I stumbled upon this post, and most of it is fine (learning rate too low, etc.), but then I read:
I thus modified the PyTorch implementation by moving the 3 lines optimizer.zero_grad(), loss.backward(), optimizer.step() inside the (training) batches loop instead of running them after the loop. As a result, this is the edited code that performs training and testing at each epoch:
# Iterate over epochs
for epoch in range(1, n_epochs + 1):
    train_loss = 0
    model.train()
    train_predictions = []
    train_true_labels = []
    # Iterate over training batches
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()  # set gradients to zero
        preds = model(inputs)
        # Compute the loss and accumulate it to print it afterwards
        loss = loss_criterion(preds, labels)
        train_loss += loss.detach()
        pred_values, pred_encoded_labels = torch.max(preds, 1)
        pred_encoded_labels = pred_encoded_labels.cpu().numpy()
        train_predictions.extend(pred_encoded_labels)
        train_true_labels.extend(labels.cpu().numpy())
        loss.backward()  # backpropagate and compute gradients
        optimizer.step()  # perform a parameter update
    # Evaluate on the development set
    predictions = []
    true_labels = []
    dev_loss = 0
    model.eval()
    with torch.no_grad():  # no gradients needed during evaluation
        for i, (inputs, labels) in enumerate(dev_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs)
            loss = loss_criterion(preds, labels)
            dev_loss += loss.detach()
            pred_values, pred_encoded_labels = torch.max(preds, 1)
            pred_encoded_labels = pred_encoded_labels.cpu().numpy()
            predictions.extend(pred_encoded_labels)
            true_labels.extend(labels.cpu().numpy())
I thus trained the network again using my default configuration (Adam, lr=0.001) and, surprisingly, obtained convergence at epoch 22 (see images below). I think the issue was there, do you agree? Do you have any additional advice? Thanks again!
What? What? So this person moved optimizer.zero_grad() into the next loop iteration, right after the data point has been loaded? How should this affect convergence at all? That can't be real, right?
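To make sure I'm reading the quoted change correctly, here is a minimal self-contained sketch of the two placements. The toy linear model, random data, and hyperparameters are made-up placeholders just so it runs, not anything from the post:

import torch
import torch.nn as nn

# Toy setup purely for illustration (all of this is hypothetical)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_criterion = nn.CrossEntropyLoss()
train_loader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(5)]
n_epochs = 3

# Placement described as the original: zero_grad/backward/step AFTER the batch loop,
# so only the last batch's loss is backpropagated and the model is updated once per epoch.
for epoch in range(1, n_epochs + 1):
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs)
        loss = loss_criterion(preds, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Placement from the quoted edit: the same three calls INSIDE the batch loop,
# i.e. one gradient computation and one parameter update per batch.
for epoch in range(1, n_epochs + 1):
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        preds = model(inputs)
        loss = loss_criterion(preds, labels)
        loss.backward()
        optimizer.step()

If that reading is right, the edit doesn't change what a single update computes (each step still backpropagates one batch's loss), only how many updates happen per epoch.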