Getting different feature vectors from frozen layers after training

Hi, I’m trying to train a classification model that, besides the scores, also returns the image’s feature vector. To do this, I used a pre-trained MobileNetV2 with a new classification layer at the end (covering only the classes I need). The goal is to freeze the feature-extraction layers and train only the new classification layer. I implemented the model as follows:

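import torch
from torch import nn
from torchvision import models


class CustomClassifier(torch.nn.Module):
    def __init__(self, out_features):
        super().__init__()
        # ImageNet-pretrained backbone (newer torchvision versions use
        # weights=models.MobileNet_V2_Weights.IMAGENET1K_V1 instead of pretrained=True)
        self.model = models.mobilenet_v2(pretrained=True, progress=False)
        # replace the 1000-class ImageNet head with one for our classes
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(in_features=1280, out_features=out_features),
            nn.LogSoftmax(dim=1),
        )
        # freeze the feature-extractor parameters
        for p in self.model.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        # run the backbone once and reuse its output for both the head and the returned features
        feats = self.model.features(x)
        pooled = nn.functional.adaptive_avg_pool2d(feats, (1, 1)).flatten(1)
        return self.model.classifier(pooled), feats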
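To confirm which tensors the optimizer can actually update, it can help to list the parameters that still have requires_grad=True. A minimal sketch, assuming a small hypothetical stand-in (`TinyBackboneNet`) that mimics MobileNetV2's `features`/`classifier` layout so it runs without downloading pretrained weights:

```python
import torch
from torch import nn


class TinyBackboneNet(nn.Module):
    """Hypothetical stand-in with the same features/classifier split as MobileNetV2."""

    def __init__(self, out_features=3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.ReLU6(inplace=True),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.2),
            nn.Linear(8, out_features),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        feats = self.features(x)
        pooled = nn.functional.adaptive_avg_pool2d(feats, (1, 1)).flatten(1)
        return self.classifier(pooled), feats


def trainable_parameter_names(model):
    # Only parameters with requires_grad=True receive gradients from backward()
    return [name for name, p in model.named_parameters() if p.requires_grad]


net = TinyBackboneNet()
for p in net.features.parameters():  # freeze the backbone, as in CustomClassifier
    p.requires_grad = False

print(trainable_parameter_names(net))
# BatchNorm running statistics are buffers, not parameters, so they never appear above
print([name for name, _ in net.named_buffers()])
```

Note that the second list (running_mean, running_var, num_batches_tracked) is untouched by the requires_grad loop.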

Then I fit my model, just like this:

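from datetime import datetime

import torch
from torch import optim
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def fit(epochs, model, opt, loss_fn, train_dl, valid_dl, metric):
    # *** MOVE MODEL TO GPU ***
    model = model.to(device)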
    
    # PyTorch LR scheduler performs learning rate annealing to boost training performance (optional)
    lr_sched = optim.lr_scheduler.CosineAnnealingLR(opt, len(train_dl)*epochs, 
                                                    eta_min=opt.defaults['lr']/1e6)
    train_start = datetime.now()
    train_losses = []
    val_losses = []
    train_metrics = []
    val_metrics = []
    
    pbar = tqdm(range(1, epochs+1))
    
    # *** TRAINING LOOP ***
    for epoch in pbar:
        train_loss, train_metric, valid_loss, valid_metric = 0., 0., 0., 0.
        
        # *** TRAIN ***
        model.train()
        for xb, yb in train_dl:                     # iterate every batch
            xb, yb = xb.to(device), yb.to(device)   # move features and target to GPU
            output, _ = model(xb)                      # predict
            loss = loss_fn(output, yb)              # compute loss
            loss.backward()                         # backprop
            opt.step()                              # adjust weights
            opt.zero_grad()
            lr_sched.step()                              
            train_loss += loss.item() * len(xb)
            train_metric += metric(output, yb) * len(xb)
            
        # *** VALIDATION ***
        model.eval()
        with torch.no_grad():  # no gradients needed for evaluation
            for xb, yb in valid_dl:
                xb, yb = xb.to(device), yb.to(device)
                output, _ = model(xb)
                loss = loss_fn(output, yb)
                valid_loss += loss.item() * len(xb)
                valid_metric += metric(output, yb) * len(xb)

        # compute metrics and print
        train_loss /= len(train_dl.dataset)
        train_losses.append(train_loss)
        train_metric /= len(train_dl.dataset)
        train_metrics.append(train_metric)
        valid_loss /= len(valid_dl.dataset)
        val_losses.append(valid_loss)
        valid_metric /= len(valid_dl.dataset)
        val_metrics.append(valid_metric)
        pbar.set_description(
            '[{}] loss: train={:.3f}, val={:.3f} -- metric: train={:.3f}, val={:.3f}'.format(
                epoch, train_loss, valid_loss, train_metric, valid_metric
            )
        )
    train_elapsed = datetime.now() - train_start
    print(f'Total time: {train_elapsed.total_seconds():.0f} s.')
    return train_losses, train_metrics, val_losses, val_metrics
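Because lr_sched.step() is called once per batch, passing len(train_dl)*epochs as T_max makes the cosine schedule span the whole run and end at eta_min. A small sketch to check that, using a dummy parameter and hypothetical step counts in place of a real data loader:

```python
import torch
from torch import optim

steps_per_epoch, epochs = 10, 3  # hypothetical stand-ins for len(train_dl) and epochs
param = torch.nn.Parameter(torch.zeros(1))
opt = optim.Adam([param], lr=3e-3)
lr_sched = optim.lr_scheduler.CosineAnnealingLR(
    opt, steps_per_epoch * epochs, eta_min=opt.defaults['lr'] / 1e6
)

lrs = []
for _ in range(steps_per_epoch * epochs):
    opt.step()       # no gradients here, so Adam leaves the parameter untouched
    lr_sched.step()  # one scheduler step per batch, as in fit()
    lrs.append(opt.param_groups[0]['lr'])

print(f"first lr={lrs[0]:.2e}, last lr={lrs[-1]:.2e}")
```

If the scheduler were stepped once per epoch instead, T_max would have to be epochs rather than len(train_dl)*epochs.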

I’m using these hyperparameters:

model = CustomClassifier(len(train_ds.classes))
loss_fn = nn.NLLLoss()
metric = accuracy
opt = optim.Adam(model.model.classifier.parameters(), lr=3e-3)
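The accuracy function itself isn't shown in the post. A minimal sketch of what such a metric could look like (hypothetical, assuming output holds the log-probabilities produced by the LogSoftmax head and the targets are integer class labels):

```python
import torch


def accuracy(output: torch.Tensor, targets: torch.Tensor) -> float:
    # Hypothetical metric: fraction of samples whose highest-scoring class matches the label.
    # argmax over log-probabilities picks the same class as argmax over probabilities.
    preds = output.argmax(dim=1)
    return (preds == targets).float().mean().item()


log_probs = torch.log_softmax(torch.tensor([[2.0, 0.5], [0.1, 1.0], [3.0, -1.0]]), dim=1)
labels = torch.tensor([0, 1, 1])
print(accuracy(log_probs, labels))  # two of three predictions are correct
```

Returning a plain float keeps `metric(output, yb) * len(xb)` in fit() working without extra .item() calls.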

Finally, I call the fit method:

train_losses, train_metrics, val_losses, val_metrics = fit(3, model, opt, loss_fn, train_dl, valid_dl, metric)

Each time I train the model, I notice that the extracted features are different for the exact same image. But this shouldn’t happen, since those layers are frozen from the beginning and the same weights are loaded every time I instantiate the model (MobileNetV2 pretrained on ImageNet).

Any ideas for what might be causing this inconsistency?

Thanks!!


Hmm, I’m not able to replicate this. The following prints True for me:

img = next(iter(test_dl))[0]

fit(1, model, opt, loss_fn, train_dl, valid_dl, metric)
_, feature1 = model(img.cuda())
fit(1, model, opt, loss_fn, train_dl, valid_dl, metric)
_, feature2 = model(img.cuda())

print(torch.allclose(feature1, feature2))

Hi @soulitzer, first of all, thanks for helping me!

Hmm, I think the right way to reproduce this is to train two different versions of the model. I train a v1 of my model and deploy it. Some time later, I train another version with more data and deploy it again, but I would like the feature vector to be the same for the exact same image. Here is an example:

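train_losses, train_metrics, val_losses, val_metrics = fit(3, model, opt, loss_fn, train_dl, valid_dl, metric)

# second version (assumed: a fresh instance set up the same way as the first)
model_1 = CustomClassifier(len(train_ds.classes))
loss_fn_1 = nn.NLLLoss()
metric_1 = accuracy
opt_1 = optim.Adam(model_1.model.classifier.parameters(), lr=3e-3)
train_losses_1, train_metrics_1, val_losses_1, val_metrics_1 = fit(5, model_1, opt_1, loss_fn_1, train_dl, valid_dl, metric_1)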

Then I run the check:

img = next(iter(test_dl))[0].to(device)  # inputs must be on the same device as the models
with torch.no_grad():
    _, feature1 = model(img)
    _, feature2 = model_1(img)
print(torch.allclose(feature1, feature2))  # <--- False
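One way to narrow down what actually differs is to compare the two backbones' state_dict() entries, which contain both parameters and buffers. A sketch, assuming two models that start from the same weights (simulated here with a small hypothetical BatchNorm backbone and copy.deepcopy); in the thread's setup one would pass model.model.features and model_1.model.features:

```python
import copy

import torch
from torch import nn


def differing_entries(a: nn.Module, b: nn.Module, atol: float = 1e-6):
    # state_dict() includes learnable parameters and buffers (e.g. BatchNorm running stats)
    sa, sb = a.state_dict(), b.state_dict()
    return [k for k in sa if not torch.allclose(sa[k].float(), sb[k].float(), atol=atol)]


torch.manual_seed(0)
backbone_v1 = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU())
backbone_v2 = copy.deepcopy(backbone_v1)  # same "pretrained" starting point
for p in list(backbone_v1.parameters()) + list(backbone_v2.parameters()):
    p.requires_grad = False

# Simulate two training runs on different data: forward passes in train mode only
backbone_v1.train()(torch.randn(8, 3, 16, 16))
backbone_v2.train()(torch.randn(8, 3, 16, 16) + 1.0)

print(differing_entries(backbone_v1, backbone_v2))
```

In this toy case only the BatchNorm running_mean and running_var entries differ; all frozen weights are still identical.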

Hi, @Francisco_Yackel :slight_smile: As far as I understand, the most probable explanation for this difference is the BatchNorm layers in your backbone. Even when you set requires_grad = False on the backbone layers, every forward pass in training mode still updates the running statistics (running_mean and running_var) of the BatchNorm layers, because those are buffers rather than parameters and are not affected by requires_grad. So you also have to call eval() on the BatchNorm layers, which makes sure the running statistics are not updated during training. This thread could be of great help to you if you want more clarification on the topic.
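The effect described above can be reproduced on a single layer. A minimal sketch, assuming a standalone nn.BatchNorm2d with frozen weight and bias and a random input batch:

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
for p in bn.parameters():  # freeze weight and bias, as the backbone does
    p.requires_grad = False

x = torch.randn(8, 4, 5, 5) + 3.0  # batch whose mean differs from the initial running_mean (0)

before = bn.running_mean.clone()
bn.train()
with torch.no_grad():  # even with autograd disabled...
    bn(x)
changed_in_train = not torch.equal(before, bn.running_mean)  # ...train mode updates the stats

before = bn.running_mean.clone()
bn.eval()
with torch.no_grad():
    bn(x)
changed_in_eval = not torch.equal(before, bn.running_mean)  # eval mode only reads them

print(changed_in_train, changed_in_eval)  # True False
```

Neither requires_grad nor torch.no_grad() controls the running statistics; only the layer's training flag does.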

I would also like to note that when you fine-tune a pretrained model, you will probably have a somewhat different dataset, and you may actually want to update those running statistics so that the BatchNorm layers can serve their intended purpose (normalizing intermediate representations) for your data as well. Either way, you can always try both approaches and see which works best for your task at hand.
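If you do want the statistics to reflect your own data but still keep them fixed between model versions trained on that data, one option (an assumption on my part, not something discussed above) is to recompute them once in a separate pass with gradients disabled, then keep the BatchNorm layers in eval mode while training the head. Setting momentum=None makes BatchNorm use a cumulative average instead of an exponential one. A sketch with a hypothetical helper (`recalibrate_bn`) and random, equal-sized batches:

```python
import torch
from torch import nn


@torch.no_grad()
def recalibrate_bn(module: nn.Module, batches):
    # Hypothetical helper: reset every BatchNorm2d and switch it to a cumulative average
    for m in module.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            m.momentum = None
    module.train()  # BatchNorm only updates running statistics in train mode
    for xb in batches:
        module(xb)
    module.eval()  # freeze the recalibrated statistics
    return module


torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
data = torch.randn(40, 3, 8, 8) * 2.0 + 5.0
recalibrate_bn(bn, data.split(10))
print(bn.running_mean)  # with equal-sized batches: the per-channel mean of all the data
```

With equal batch sizes the cumulative running_mean equals the dataset's per-channel mean, so the result no longer depends on how long or in what order the head is trained afterwards.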

Cheers :slight_smile:


Hi @Alexey_Demyanchuk, you are right!!
Thanks for helping me!
Adding this to my fit method, between model.train() and the loop:

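# *** TRAIN ***
model.train()
# put every BatchNorm layer back into eval mode so its running statistics stay frozen
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()
for xb, yb in train_dl:                     # iterate every batch
    ...                                     # rest of the training loop unchanged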

fixed the inconsistency.
Thanks again :smiley:
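An alternative to placing the BatchNorm loop inside fit() is to override the model's train() method, so the backbone's BatchNorm layers stay in eval mode whenever model.train() is called. A minimal sketch of that pattern on a hypothetical toy model (`FrozenBackboneNet`, `train_head` and all sizes are illustrative, not from the thread), which also checks that two heads trained separately from the same starting weights produce identical backbone features:

```python
import copy

import torch
from torch import nn


class FrozenBackboneNet(nn.Module):
    """Hypothetical toy model: frozen BatchNorm backbone plus a trainable head."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1, bias=False), nn.BatchNorm2d(8), nn.ReLU6()
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.2), nn.Linear(8, num_classes), nn.LogSoftmax(dim=1)
        )
        for p in self.features.parameters():
            p.requires_grad = False

    def train(self, mode: bool = True):
        # Let model.train() toggle Dropout etc., but keep backbone BatchNorm in eval mode
        super().train(mode)
        for m in self.features.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        return self

    def forward(self, x):
        feats = self.features(x)
        pooled = nn.functional.adaptive_avg_pool2d(feats, (1, 1)).flatten(1)
        return self.classifier(pooled), feats


def train_head(model, steps, seed):
    # Hypothetical training run on random data; only the head's parameters are optimized
    g = torch.Generator().manual_seed(seed)
    opt = torch.optim.Adam(model.classifier.parameters(), lr=1e-2)
    model.train()
    for _ in range(steps):
        xb = torch.randn(16, 3, 12, 12, generator=g)
        yb = torch.randint(0, 3, (16,), generator=g)
        loss = nn.functional.nll_loss(model(xb)[0], yb)
        opt.zero_grad()
        loss.backward()
        opt.step()


torch.manual_seed(0)
v1 = FrozenBackboneNet()
v2 = copy.deepcopy(v1)  # same starting weights, like two fresh pretrained instances
train_head(v1, steps=3, seed=1)
train_head(v2, steps=5, seed=2)  # a "second version" trained longer on other data

img = torch.randn(2, 3, 12, 12)
v1.eval()
v2.eval()
with torch.no_grad():
    print(torch.allclose(v1(img)[1], v2(img)[1]))  # backbone features match
```

Because eval() calls train(False) internally, the override also covers evaluation, and the fix cannot be forgotten in a new training loop.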

Great :slight_smile: Nice to hear!