PyTorch last & best models have the exact same weights!

Hi everyone,

I am building a PyTorch training function where I intend to save both the best model and the last model. However, both models end up with exactly the same weights, even though I am adding a condition on the validation loss. Below is my code:

model = NeuralNet(in_dimension=2, out_dimension=1)
num_epoch = 500
optimizer = torch.optim.Adam(…)
criterion = torch.nn.BCELoss(…)
train_loader = DataLoader(…)
val_loader = DataLoader(…)

def a_training(num_epoch, model, optimizer, criterion, train_loader, val_loader):
    best_model = NeuralNet(in_dimension=2, out_dimension=1)  # initializing best model
    best_val_loss = None
    for epoch in range(num_epoch):
        train_loss = train_epoch(epoch, optimizer, criterion, model, train_loader)
        val_loss = validate_epoch(epoch, criterion, model, val_loader)
        if epoch == 0:
            best_val_loss = val_loss
            torch.save(model, '/content/_first_model.pt')
        else:
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model = model
                torch.save(best_model, '/content/_best_model.pt')
    torch.save(model, '/content/_last_model.pt')
    return None
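One caveat in the code above (separate from the question of why the files match): `best_model = model` only copies a reference, so `best_model` and `model` are always the same object. `torch.save` serializes the weights at the moment it is called, so the saved files are unaffected here, but if you want to keep an in-memory snapshot of the best weights, a common pattern is to deep-copy the state dict. A minimal sketch, using a stand-in `nn.Linear` instead of the original `NeuralNet`:

```python
import copy

import torch
import torch.nn as nn

# Hypothetical stand-in model, just for illustration
model = nn.Linear(2, 1)

# Snapshot the weights by value, not by reference
best_state = copy.deepcopy(model.state_dict())

# Simulate training changing the weights afterwards
with torch.no_grad():
    model.weight.add_(1.0)

# The snapshot is unaffected; a plain reference would have changed too
same = torch.equal(best_state["weight"], model.state_dict()["weight"])
print(same)  # prints False
```

With a plain assignment (`best_state = model.state_dict()`), the check above would print `True`, because the state dict tensors share storage with the live parameters.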

Then I run the function a_training(…), and later I load the models and check their weights, but the last and best models are the same!

# Verifying that last and best models have different weights in the final layer

_first_model = torch.load('/content/_first_model.pt')
_best_model = torch.load('/content/_best_model.pt')
_last_model = torch.load('/content/_last_model.pt')

print("First model", torch.sum(_first_model.fc2.weight.data))
print("Best model", torch.sum(_best_model.fc2.weight.data))
print("last model", torch.sum(_last_model.fc2.weight.data))

and I am getting:

First model tensor(-0.4032)
Best model tensor(-4.8789)
last model tensor(-4.8789)

Couldn't the last model also be the best model? Or did you verify that the condition is not met in the last training iteration?
You might also want to calculate the abs().max() error between some parameters, as a purely visual comparison might not be sufficient and lower decimal places might have changed.
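To make that comparison concrete, here is a minimal sketch of comparing two checkpoints parameter by parameter (the two `nn.Linear` models are hypothetical stand-ins for the loaded `_best_model` and `_last_model`):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for two loaded checkpoints
model_a = nn.Linear(2, 1)
model_b = nn.Linear(2, 1)

# Max absolute difference per parameter; 0.0 means the tensors are identical
for (name, p_a), p_b in zip(model_a.state_dict().items(),
                            model_b.state_dict().values()):
    print(name, (p_a - p_b).abs().max().item())
```

If every printed value is exactly 0.0, the two checkpoints really do hold identical weights; summing a layer's weights, as in the original check, can in principle hide differences that cancel out.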

I tried different numbers of epochs, from 100 to 500, and they always result in the same fc2 layer weights. Yes, you are right, I need to verify that the condition is working properly.
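One simple way to verify it is to log whenever the condition fires. A minimal sketch with made-up validation losses (not real results): if the final epoch also produces the lowest validation loss, the best and last checkpoints will legitimately be identical.

```python
# Hypothetical validation losses, for illustration only
val_losses = [0.9, 0.7, 0.8, 0.6]

best_val_loss = float("inf")
for epoch, val_loss in enumerate(val_losses):
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Log each new best so you can see which epoch it came from
        print(f"epoch {epoch}: new best val_loss {val_loss}")
```

Here the last epoch (epoch 3) is also the best one, so saving "best" and "last" at that point would produce identical files, exactly the situation described in the reply above.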