[BUG] Weird behavior between evaluation and training mode

Actually, it is an ongoing project, and I don’t have permission to share some parts of it. If possible, please tell me what you would like to see (intermediate outputs or anything else) and I will share it with you! :slight_smile:

Yeah, it should update them, even if they are excluded from what the optimizer can see. But it’s worth mentioning that the DataParallel wrapper doesn’t properly update the running statistics. https://github.com/pytorch/pytorch/issues/1051
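
A minimal sketch of what gets updated (just a toy BatchNorm2d layer for illustration): even with requires_grad disabled on every parameter, a forward pass in training mode still changes the running statistics.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
for p in bn.parameters():        # freeze the affine weight and bias
    p.requires_grad = False

before = bn.running_mean.clone()
bn(torch.randn(4, 3, 8, 8))      # forward pass while bn.training is True
print((bn.running_mean - before).abs().sum())   # non-zero: stats were updated

bn.eval()                        # in eval mode the running stats stay fixed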

1 Like

Could you tell us in detail how you froze the pretrained part?

Yes, sure. Using the following functions, I set requires_grad to False or True depending on whether a layer belongs to my pre-trained section.

def release_weight(model):
    for param in model.parameters():
        param.requires_grad = True
    return model
def freeze_weight(model):
    for param in model.stage1.parameters():
        param.requires_grad = False
    for param in model.stage4.parameters():
        param.requires_grad = False
    for param in model.stage5.parameters():
        param.requires_grad = False       
    for param in model.stage5_1.parameters():
        param.requires_grad = False
    for param in model.parallel1.parameters():
        param.requires_grad = False
    for param in model.parallel2.parameters():
        param.requires_grad = False
    for param in model.stage7.parameters():
        param.requires_grad = False
    for param in model.final.parameters():
        param.requires_grad = False
    for param in model.d.parameters():
        param.requires_grad = False
    return model

Actually, when I instantiate the model, I call the above two functions:

model = Model(pretrained_path, 425)
model = release_weight(model)
model = freeze_weight(model)

Then I select the parameters which have requires_grad=True using the following lines and pass them to the optimizer:

parameters = itertools.filterfalse(lambda p: not p.requires_grad, model.parameters())
optimizer = optim.SGD(parameters, lr=baseLR, weight_decay=0.001, momentum=0.9)
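
For what it’s worth, here is an equivalent but more compact version of the freeze logic (using the same submodule names as above), and the parameter selection written without the double negation:

# same submodule names as in freeze_weight above
frozen = ['stage1', 'stage4', 'stage5', 'stage5_1', 'parallel1',
          'parallel2', 'stage7', 'final', 'd']

def freeze_weight(model):
    for name in frozen:
        for param in getattr(model, name).parameters():
            param.requires_grad = False
    return model

# same selection as the filterfalse call above
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(parameters, lr=baseLR, weight_decay=0.001, momentum=0.9)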

I have the same problem. I’m not sure whether it is caused by using DataParallel or something else.

@SimonW That’s good to know! Thanks for the link!

@mderakhshani Ok, I understand. Do you see the same error when you load both models from the checkpoints, set them to evaluation mode and pass a random sample through them?
Do you use any “special” layers, e.g. with conditions etc.?
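
For example, a small helper along these lines (the input shape is just a placeholder; both models are assumed to be loaded from their checkpoints already and to return a single tensor):

import torch

def compare_eval_outputs(model_a, model_b, input_shape=(1, 3, 256, 256)):
    model_a.eval()
    model_b.eval()
    x = torch.randn(*input_shape)                 # one random sample
    with torch.no_grad():
        diff = (model_a(x) - model_b(x)).abs().max()
    print('max abs difference:', diff.item())     # should be ~0 if the models match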

@mfayyaz Are you able to share a small code snippet reproducing this issue?

Wow! It’s amazing!
Test your saving procedure and load your parameters again. I had a similar issue, and it was solved by checking the save and load procedure.
Also, the definition of your new layers could cause this problem.
I hope it gets resolved soon…

1 Like

Let me give you all the possible permutations of my networks. :smiley:

Rule1: pre-trained net: a network without any training; I just loaded the pre-trained weights and initialized the extra layers.

Rule2: checkpoint: a network initialized with the pre-trained weights and trained on my dataset (no parameter updates in my pre-trained, frozen section).

  1. pre-trained and pre-trained in training mode: the outputs are equal.
  2. pre-trained and pre-trained in evaluation mode: the outputs are equal.
  3. pre-trained and checkpoint0 in training mode: the outputs are equal.
  4. pre-trained and checkpoint0 in evaluation mode: the outputs are not equal.
  5. checkpoint0 and checkpoint1 in training mode: the outputs are equal.
  6. checkpoint0 and checkpoint1 in evaluation mode: the outputs are equal.

I do not have any special layers in my net. All of them are regular layers such as Conv2d, MaxPool, and BatchNorm, nothing else.

Hello @SimonW, @mfayyaz, @ptrblck, @Khalooei. I have prepared two scripts: one for training and the other for comparing the outputs. Here is my main script, which overfits on a single sample starting from the pre-trained resnet18 weights.

import torch
import torchvision.models as models
import torch.nn as nn
from torch.autograd import Variable as V
import torch.optim as optim


class MyModel(nn.Module):
    """docstring for MyModel"""
    def __init__(self, resnet18):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        self.f1 = nn.Linear(2048, 4)

    def forward(self, input):
        out = self.features(input).detach()
        return self.f1(out.view(out.size(0), -1))
		
def release_weight(model):
    for param in model.parameters():
        param.requires_grad = True
    return model

def freeze_weight(model):
    for param in model.features.parameters():
        param.requires_grad = False
    return model    

resnet18 = models.resnet18(pretrained = True)
model = MyModel(resnet18)
model = release_weight(model)
model = freeze_weight(model)

parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.SGD(parameters, lr=1e-3, weight_decay=0.001, momentum=0.9)

inp = V(torch.randn(1,3,256,256), requires_grad = False)
target = V(torch.randn(1,4))

for i in range(2):
    for j in range(100):
        optimizer.zero_grad()
        out = model(inp)
        loss = ((out - target) ** 2).mean()
        loss.backward()
        optimizer.step()
        print(loss.data[0])
    torch.save(model.state_dict(), './simpleModel{}.pth'.format(i))

and then here is the comparing code:

import torch
import torchvision.models as models
import torch.nn as nn
from torch.autograd import Variable as V

class MyModel(nn.Module):
    """docstring for MyModel"""
    def __init__(self, resnet18):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        self.f1 = nn.Linear(2048, 4)

    def forward(self, input):
        out1 = self.features(input)
        return out1, self.f1(out1.view(out1.size(0),-1))

ckpt0 = 'simpleModel0.pth'
ckpt1 = 'simpleModel1.pth'

resnet18 = models.resnet18(pretrained = True)
model0 = MyModel(resnet18)
model0.load_state_dict(torch.load(ckpt0), strict=True)
model0.train(False)


resnet18 = models.resnet18(pretrained = True)
model1 = MyModel(resnet18)
# model1.load_state_dict(torch.load(ckpt1))
model1.train(False)

model0_name = []
for name, param in model0.named_parameters():
    model0_name.append(name)

model1_name = []
for name, param in model1.named_parameters():
    model1_name.append(name)

diff = []
for p1, p2 in zip(model0.parameters(), model1.parameters()):
    diff.append((p1.data - p2.data).sum())

for n in zip(model0_name, model1_name, diff):
    print(n[0], n[1], "{0:.8f}".format(n[2]))

model0_running_variance = []
model0_running_mean = []
for module in model0.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model0_running_mean.append(module.running_mean)
        model0_running_variance.append(module.running_var)

model1_running_variance = []
model1_running_mean = []
for module in model1.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model1_running_mean.append(module.running_mean)
        model1_running_variance.append(module.running_var)

print("running_mean difference")
for m1, m2 in zip(model0_running_mean, model1_running_mean):
    print("{0:.8f}".format((m1 - m2).sum()))

print("running_var difference")
for m1, m2 in zip(model0_running_variance, model1_running_variance):
    print("{0:.8f}".format((m1 - m2).sum()))

inp = V(torch.randn(1,3,256,256))
pretrained_head0, out0 = model0(inp)
pretrained_head1, out1 = model1(inp)


diffpretrained_head = (pretrained_head1 - pretrained_head0).data.abs().sum()
diffout = (out1 - out0).data.abs().sum()

print(diffpretrained_head, diffout)

And again, these scripts confirm my claim. Could you please check them on your system?

Hi, I had the same problem a couple of months ago. I thought there was something wrong in my code!

1 Like

I debugged your code a bit and it seems that the BatchNorm layers differ.
You can’t see it, since you have a typo in saving the running_mean and running_var:

model0_running_variance = []
model0_running_mean = []
for module in model0.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model0_running_mean.append(module.running_mean)
        model0_running_variance.append(module.running_var)

model1_running_variance = []
model1_running_mean = []
for module in model0.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model1_running_mean.append(module.running_mean)
        model1_running_variance.append(module.running_var)

In both loops you iterate over model0.
I will check why they differ.

1 Like

The problem is that your BN layers differ.
I used the following code to solve the problem (just override the train function of your model):

    def train(self, mode=True, freeze_bn=False, freeze_bn_affine=False):
        
        super(MyModel, self).train(mode)
        if freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    '''Freezing Mean/Var of BatchNorm2D'''
                    m.eval()
                    if freeze_bn_affine:
                        '''Freezing Weight/Bias of BatchNorm2D'''
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False
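
A usage sketch for the override above (num_epochs is a placeholder; the other names come from the training script earlier in the thread): call the overridden train with freeze_bn=True after building the model, so the BatchNorm layers stay in eval mode while the new layers train.

model = MyModel(resnet18)
model = release_weight(model)
model = freeze_weight(model)
model.train(mode=True, freeze_bn=True)   # BN running stats are no longer updated

for epoch in range(num_epochs):          # num_epochs is a placeholder
    optimizer.zero_grad()
    out = model(inp)
    loss = ((out - target) ** 2).mean()
    loss.backward()
    optimizer.step()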

5 Likes

Thanks. I have edited my reply! Sorry.

I’ve checked the training part and the running_* stats are being updated.
You have to set the BatchNorm layers to evaluation mode.

Add this to your training:

def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()
model.apply(freeze_bn)
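
Note that model.train() puts every submodule, including the BatchNorm layers, back into training mode, so in a loop like the following sketch the apply call has to be repeated after each call to train():

for epoch in range(num_epochs):   # num_epochs is a placeholder
    model.train()                 # re-enables training mode for every submodule
    model.apply(freeze_bn)        # so the BN layers have to be frozen again
    # ... run the training steps for this epoch ...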

EDIT: Too late, @mfayyaz was faster! :wink: Good catch!

3 Likes

Wooow. Thanks @mfayyaz. It was solved by your snippet.

Thanks :grinning::wink:

Hi Mohammad

(output1 and output2 attachments omitted)

As @mfayyaz and @ptrblck said, the issue occurs in the BN (BatchNorm) section.

have a nice time :slight_smile:

So what is the problem? You mean that .train(False) does not work properly?

You have never set it to .train(False) in your training script. You just disabled the gradients.
That’s why BatchNorm still updated its running_* stats.
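
Applied to the training script above, that would mean something like this (a sketch using the freeze_bn helper from the earlier post):

model = MyModel(resnet18)
model = release_weight(model)
model = freeze_weight(model)       # only disables the gradients
model.features.apply(freeze_bn)    # additionally stops the BN running_* updates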

1 Like

Okay. I have got it. Thank you