Decreasing weights without optimization

I’m trying to fine-tune a pretrained ResNet model via a triplet loss. Due to the nature of the loss function, I sometimes make multiple forward passes through the model without ever doing a backward pass or stepping the optimizer. It seems that the norm of the final layer’s output decreases even when I never explicitly optimize.

I have created a small toy example below to illustrate what I’m talking about. I run an image forward through the model multiple times, then pass the same image through as a “validation” image and calculate the norm of the final-layer output. Maybe this behavior is a result of the BatchNorm layers? If so, I would love to know how those layers create this behavior.

import io

import numpy as np
from PIL import Image
import requests

import torch
from torch.autograd import Variable
from torchvision.models import resnet50
import torchvision.transforms as transforms

# Create finetuning model.
model = resnet50(pretrained=True)
for name, child in model.named_children():
    if name != 'fc':
        for p in child.parameters():
            p.requires_grad = False

            
optimizer = torch.optim.SGD(filter(lambda x: x.requires_grad, 
                                   model.parameters()),
                            lr=1e-3)

# Download a random image and transform for input
# to resnet.
url = ('https://raw.githubusercontent.com/pytorch/'
       'pytorch/master/docs/source/_static/img/'
       'pytorch-logo-dark.png')
response = requests.get(url)
img = Image.open(io.BytesIO(response.content))
transform = transforms.Compose([
    transforms.Scale(224),
    transforms.CenterCrop(224),
    transforms.ToTensor()])
img = transform(img)[:3, :, :]

# "Training" and "Validation" data are the exact
# same image.
train = img.unsqueeze(0)
val = Variable(img.unsqueeze(0), volatile=True)

# Each epoch: one training-mode forward pass, then
# measure the norm of the validation embedding in
# eval mode. No backward passes, no optimizer steps.
for epoch in range(20):
    
    model.train()
    optimizer.zero_grad()
    train_emb = model(Variable(train))
    
    model.eval()
    val_emb = model(val)
    print(val_emb.norm(2).data[0])

Sounds like BatchNorm. When you call .train() on your net, you put all of its child modules into training mode, including the BatchNorm modules. When BatchNorm is in training mode, every forward pass updates its running mean and running variance buffers on its own (i.e. with no interaction with SGD or any user-defined optimizer), which will definitely have an effect on your observed output. I’m not sure whether this still holds when the Variable being passed through is volatile, but given what you’re seeing, I’m going to guess that it does.
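To see this concretely, here’s a minimal check you could run (reusing model and train from your snippet; bn1 is the first BatchNorm layer in torchvision’s resnet50): the running-statistics buffers move after training-mode forward passes alone, with no backward pass and no optimizer step.

# Minimal check, assuming `model` and `train` from the snippet above.
before = model.bn1.running_mean.clone()

model.train()
for _ in range(5):
    model(Variable(train))         # forward only: no backward, no optimizer step

after = model.bn1.running_mean
print((after - before).abs().max())  # non-zero: the buffer drifted on its own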

If you don’t want these updates to happen, put the BatchNorm modules in inference mode (model.eval() does this for the whole network), or set just the specific BatchNorm modules to inference mode if you need the other modules in training mode: loop through all modules, check whether each one is an nn.BatchNorm2d, and call .eval() on it if it is.
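A sketch of that loop (assuming every BatchNorm layer in your resnet50 is an nn.BatchNorm2d, which is the case for the torchvision model):

import torch.nn as nn

model.train()                        # everything in training mode...
for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        module.eval()                # ...except the BatchNorm layers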

Or you could probably set the momentum term of each BatchNorm module to 0. IIRC, that’s the term that controls the update rate of the running means and variances, so if you still want to use the per-batch means and variances but don’t want to update the running statistics, setting momentum to 0 should allow you to do that.
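Something like the following sketch (again assuming the resnet50 above): with PyTorch’s update rule running = (1 - momentum) * running + momentum * batch_stat, a momentum of 0 leaves the running buffers untouched while training mode still normalizes with the per-batch statistics.

import torch.nn as nn

for module in model.modules():
    if isinstance(module, nn.BatchNorm2d):
        # per-batch stats are still used in .train(); running stats stay frozen
        module.momentum = 0.0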

Gotcha, thanks so much for the explanation. I guess it still weirds me out that the running mean and variance should be continually changing when feeding the same data into the layer, but I don’t have a great handle on how that calculation is performed (and got a bit lost trying to delve into the C code that actually implements it).

Nonetheless, thanks for the explanation and the recommended ways to deal with it!

Passing a single image through a BatchNorm network is a terrible idea, especially in training mode: your batch statistics end up being just the per-channel statistics of that one image.

What’s happening with running_mean/running_std?

During training, this layer keeps a running estimate of its computed mean and variance. The running estimates are kept with a default momentum of 0.1.

i.e. it’s updated as: running_mean = 0.9 * running_mean + 0.1 * current_mean
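You can check that rule directly on a fresh standalone layer (a toy sketch, not your model; running_mean starts at zeros and momentum defaults to 0.1):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)                    # fresh layer: running_mean is zeros
x = torch.randn(8, 3, 16, 16)
batch_mean = x.mean(dim=(0, 2, 3))        # per-channel mean of this batch

bn.train()
bn(x)                                     # one forward pass, nothing else

expected = 0.9 * torch.zeros(3) + 0.1 * batch_mean
print(torch.allclose(bn.running_mean, expected))   # True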

I’d suggest that you iterate over your modules and set all the BatchNorm layers to .eval(), like this:

model.apply(lambda x: x.eval() if 'BatchNorm' in str(type(x)) else None)

Passing a single instance to BatchNorm was just for demonstration purposes in my toy example. I’m not actually doing this for the real problem I was trying to solve.

Thanks for the update calculation! It looks like running_mean is initialized at zero for a freshly constructed layer (and starts from the saved statistics in a pretrained model). So even if current_mean never changes between updates, running_mean still has to crawl slowly from its starting value toward the static current_mean.
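Iterating that update rule by hand makes the crawl obvious (toy numbers: start the running mean at 0 and feed a constant current_mean of 5.0 at every step):

running_mean, current_mean = 0.0, 5.0     # toy values; current_mean never changes
for step in range(1, 11):
    running_mean = 0.9 * running_mean + 0.1 * current_mean
    print(step, round(running_mean, 3))
# step 1 -> 0.5, step 2 -> 0.95, ..., step 10 -> ~3.26: a geometric approach to 5.0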

Anecdotal and orthogonal, but I’ve successfully trained object detection convnets (a homebrew that’s somewhere between OverFeat and ResNet Faster R-CNN) with single-image batches and BatchNorm at every layer. Not sure why it works (I would expect it to fail), but it trains just fine.
