Wrongly placed `model.train()` but training/validation losses/metrics all improving

Hi, I accidentally forgot to call model.train() at the start of each epoch, while my evaluation function still calls model.eval() every epoch. Strangely, the training and validation losses and metrics all kept improving as training proceeded (which made me overlook the issue for a month), but when I extract the final node embeddings (I'm training a GNN), they are clearly not well trained. Of course, everything goes back to normal once I put model.train() back in the right place. I'm a little confused.

Demo code:

The main function looks like:

# main()

model.train()  # this is put in the wrong place!

for i in range(num_epoch):
    train(model, optim, ...)

edge_all_loader, node_embeddings = generate_train_batch(xxx)

with torch.no_grad():  # generate final embeddings for all nodes
    out_all = model(edge_all_loader, node_embeddings, xxx)

while the train function looks like:

edge_train_loader, node_embeddings = generate_train_batch(xxx)
edge_val_loader, node_embeddings = generate_val_batch(xxx)  # the two initial node embeddings are the same

# model.train()  # It should've been here

node_embeddings_train = run_train(model, edge_train_loader, node_embeddings)  # another function that runs batch training, where the loss is backpropagated and the model is (thought to be) updated

with torch.no_grad():
    out_val = model(edge_val_loader, node_embeddings, xxx)


model.train() changes the behavior of some layers: it enables dropout, and it makes batchnorm layers use the batch statistics while also updating their running stats. Depending on the model architecture and which layers are used, model.train()/.eval() may or may not have any effect (you should still call them in the right places in case you later add e.g. dropout layers).
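The dropout difference is easy to see in isolation. A minimal sketch (a standalone toy, not the model from the question): in train mode, nn.Dropout randomly zeroes elements and rescales the survivors by 1/(1-p); in eval mode it is an identity op.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(4, 256)

drop.train()            # training mode: elements are randomly zeroed
out_train = drop(x)     # survivors are scaled by 1 / (1 - p) = 2.0

drop.eval()             # eval mode: dropout is a no-op
out_eval = drop(x)      # identical to x
```

With the module stuck in eval mode, the network effectively trains without dropout, so it can still fit the link-prediction objective, just without the intended regularization.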

Thanks Piotr, you reminded me that dropout must have been causing this weird behavior. It's still quite intriguing that the link-prediction task went so well without dropout while the learned embeddings came out so badly. Such interesting behavior…
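For anyone else hitting this: the reason the losses kept improving is that model.eval() only toggles layer behavior (dropout, batchnorm statistics); it does not freeze parameters or block autograd. A minimal check with a made-up toy model (not the GNN from the question) showing that parameters still update in eval mode:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
model.eval()  # "wrong" mode, but gradients still flow

opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

before = model[0].weight.clone()
loss = F.mse_loss(model(x), y)
loss.backward()   # autograd is unaffected by eval mode
opt.step()        # parameters are updated as usual
after = model[0].weight
```

So training "works" either way; the difference is only that dropout never fires, which is why the optimization looked healthy while the regularization (and the resulting embeddings) silently suffered.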