[BUG] Weird behavior between evaluation and training mode

Hello there,

I have two networks. The first one is initialized with a pre-trained model plus some extra layers I defined; the second one is the same network trained for one epoch. Note that the pre-trained section of both networks is frozen so that its parameters are not updated, and my optimizer only trains the extra layers. To make sure that the pre-trained section of both networks really was not updated, I compared their parameters using the following code.

import torch
import torch.nn as nn

# Model, model_path and ckpt0 are defined in my project (not shown)
model0 = Model(model_path, 425)
model0 = nn.DataParallel(model0)
model0.load_state_dict(torch.load(ckpt0))
model0.train(True)

model1 = Model(model_path, 425)
model1 = nn.DataParallel(model1)
model1.train(True)

model0_name = []  # layers' names
for name, param in model0.named_parameters():
    model0_name.append(name)

model1_name = []  # layers' names
for name, param in model1.named_parameters():
    model1_name.append(name)

# note: a signed sum can hide differences that cancel out
diff = []
for p1, p2 in zip(model0.parameters(), model1.parameters()):
    diff.append((p1.data - p2.data).sum())

for n in zip(model0_name, model1_name, diff):
    print(n[0], n[1], "{0:.8f}".format(n[2]))
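Note that `named_parameters()` only covers learnable parameters; BatchNorm's `running_mean`, `running_var` and `num_batches_tracked` are buffers and are skipped by this check. A minimal sketch that compares the full `state_dict` key by key instead (the helper `diff_state_dicts` and the toy model are hypothetical, not part of the original project):

```python
import copy

import torch
import torch.nn as nn


def diff_state_dicts(sd0, sd1):
    """Return the keys whose tensors differ, covering parameters AND buffers."""
    diffs = []
    for key in sd0:
        if key not in sd1 or not torch.equal(sd0[key], sd1[key]):
            diffs.append(key)
    return diffs


# Hypothetical toy model standing in for the real network.
torch.manual_seed(0)
model0 = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model1 = copy.deepcopy(model0)

# One forward pass in training mode updates model1's BatchNorm buffers,
# even though no optimizer step is taken and no parameter changes.
model1.train()
model1(torch.randn(4, 3, 16, 16))

print(diff_state_dicts(model0.state_dict(), model1.state_dict()))
```

Only the BatchNorm buffer keys show up as different here, which a parameter-only comparison cannot detect.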

And here is the output:

module.stage1.0.weight module.stage1.0.weight 0.00000000
module.stage1.1.weight module.stage1.1.weight 0.00000000
module.stage1.1.bias module.stage1.1.bias 0.00000000
module.stage1.4.weight module.stage1.4.weight 0.00000000
module.stage1.5.weight module.stage1.5.weight 0.00000000
module.stage1.5.bias module.stage1.5.bias 0.00000000
module.stage1.8.weight module.stage1.8.weight 0.00000000
module.stage1.9.weight module.stage1.9.weight 0.00000000
module.stage1.9.bias module.stage1.9.bias 0.00000000
module.stage1.11.weight module.stage1.11.weight 0.00000000
module.stage1.12.weight module.stage1.12.weight 0.00000000
module.stage1.12.bias module.stage1.12.bias 0.00000000
module.stage1.14.weight module.stage1.14.weight 0.00000000
module.stage1.15.weight module.stage1.15.weight 0.00000000
module.stage1.15.bias module.stage1.15.bias 0.00000000
module.stage4.0.weight module.stage4.0.weight 0.00000000
module.stage4.1.weight module.stage4.1.weight 0.00000000
module.stage4.1.bias module.stage4.1.bias 0.00000000
module.stage4.3.weight module.stage4.3.weight 0.00000000
module.stage4.4.weight module.stage4.4.weight 0.00000000
module.stage4.4.bias module.stage4.4.bias 0.00000000
module.stage4.6.weight module.stage4.6.weight 0.00000000
module.stage4.7.weight module.stage4.7.weight 0.00000000
module.stage4.7.bias module.stage4.7.bias 0.00000000
module.stage5.0.weight module.stage5.0.weight 0.00000000
module.stage5.1.weight module.stage5.1.weight 0.00000000
module.stage5.1.bias module.stage5.1.bias 0.00000000
module.stage5.3.weight module.stage5.3.weight 0.00000000
module.stage5.4.weight module.stage5.4.weight 0.00000000
module.stage5.4.bias module.stage5.4.bias 0.00000000
module.stage5.6.weight module.stage5.6.weight 0.00000000
module.stage5.7.weight module.stage5.7.weight 0.00000000
module.stage5.7.bias module.stage5.7.bias 0.00000000
module.stage5.9.weight module.stage5.9.weight 0.00000000
module.stage5.10.weight module.stage5.10.weight 0.00000000
module.stage5.10.bias module.stage5.10.bias 0.00000000
module.stage5.12.weight module.stage5.12.weight 0.00000000
module.stage5.13.weight module.stage5.13.weight 0.00000000
module.stage5.13.bias module.stage5.13.bias 0.00000000
module.parallel1.0.weight module.parallel1.0.weight 0.00000000
module.parallel1.1.weight module.parallel1.1.weight 0.00000000
module.parallel1.1.bias module.parallel1.1.bias 0.00000000
module.parallel1.3.weight module.parallel1.3.weight 0.00000000
module.parallel1.4.weight module.parallel1.4.weight 0.00000000
module.parallel1.4.bias module.parallel1.4.bias 0.00000000
module.parallel1.6.weight module.parallel1.6.weight 0.00000000
module.parallel1.7.weight module.parallel1.7.weight 0.00000000
module.parallel1.7.bias module.parallel1.7.bias 0.00000000
module.parallel1.9.weight module.parallel1.9.weight 0.00000000
module.parallel1.10.weight module.parallel1.10.weight 0.00000000
module.parallel1.10.bias module.parallel1.10.bias 0.00000000
module.parallel1.12.weight module.parallel1.12.weight 0.00000000
module.parallel1.13.weight module.parallel1.13.weight 0.00000000
module.parallel1.13.bias module.parallel1.13.bias 0.00000000
module.parallel1.15.weight module.parallel1.15.weight 0.00000000
module.parallel1.16.weight module.parallel1.16.weight 0.00000000
module.parallel1.16.bias module.parallel1.16.bias 0.00000000
module.parallel1.18.weight module.parallel1.18.weight 0.00000000
module.parallel1.19.weight module.parallel1.19.weight 0.00000000
module.parallel1.19.bias module.parallel1.19.bias 0.00000000
module.parallel2.0.weight module.parallel2.0.weight 0.00000000
module.parallel2.1.weight module.parallel2.1.weight 0.00000000
module.parallel2.1.bias module.parallel2.1.bias 0.00000000
module.stage7.0.weight module.stage7.0.weight 0.00000000
module.stage7.1.weight module.stage7.1.weight 0.00000000
module.stage7.1.bias module.stage7.1.bias 0.00000000
module.final.0.weight module.final.0.weight 0.00000000
module.extra1.conv_l1.weight module.extra1.conv_l1.weight 1.50025647
module.extra1.conv_l1.bias module.extra1.conv_l1.bias -0.37570643
module.extra1.conv_l2.weight module.extra1.conv_l2.weight 27.04075950
module.extra1.conv_l2.bias module.extra1.conv_l2.bias 0.08830106
module.extra1.conv_r1.weight module.extra1.conv_r1.weight 29.77446331
module.extra1.conv_r1.bias module.extra1.conv_r1.bias -0.20370359
module.extra1.conv_r2.weight module.extra1.conv_r2.weight -12.84376761
module.extra1.conv_r2.bias module.extra1.conv_r2.bias -0.37359143
module.extra2.conv_l1.weight module.extra2.conv_l1.weight -10.23451808
module.extra2.conv_l1.bias module.extra2.conv_l1.bias 0.01333411
module.extra2.conv_l2.weight module.extra2.conv_l2.weight 11.42386136
module.extra2.conv_l2.bias module.extra2.conv_l2.bias 0.10939794
module.extra2.conv_r1.weight module.extra2.conv_r1.weight 1.97030002
module.extra2.conv_r1.bias module.extra2.conv_r1.bias 0.05921281
module.extra2.conv_r2.weight module.extra2.conv_r2.weight 13.66498772
module.extra2.conv_r2.bias module.extra2.conv_r2.bias 0.14407422
module.extra3.conv_l1.weight module.extra3.conv_l1.weight 1.52511122
module.extra3.conv_l1.bias module.extra3.conv_l1.bias 0.18012815
module.extra3.conv_l2.weight module.extra3.conv_l2.weight 20.19549281
module.extra3.conv_l2.bias module.extra3.conv_l2.bias -0.76473302
module.extra3.conv_r1.weight module.extra3.conv_r1.weight -2.59647552
module.extra3.conv_r1.bias module.extra3.conv_r1.bias 0.14506025
module.extra3.conv_r2.weight module.extra3.conv_r2.weight 11.67924830
module.extra3.conv_r2.bias module.extra3.conv_r2.bias -0.00651512
module.extra4.conv_l1.weight module.extra4.conv_l1.weight 14.54665439
module.extra4.conv_l1.bias module.extra4.conv_l1.bias 0.08106837
module.extra4.conv_l2.weight module.extra4.conv_l2.weight 16.46649296
module.extra4.conv_l2.bias module.extra4.conv_l2.bias -0.26476345
module.extra4.conv_r1.weight module.extra4.conv_r1.weight -12.99556065
module.extra4.conv_r1.bias module.extra4.conv_r1.bias -0.05485360
module.extra4.conv_r2.weight module.extra4.conv_r2.weight -2.79881258
module.extra4.conv_r2.bias module.extra4.conv_r2.bias -1.07567936
module.extra_1.bn.weight module.extra_1.bn.weight -152.28462493
module.extra_1.bn.bias module.extra_1.bn.bias 0.31378557
module.extra_1.conv1.weight module.extra_1.conv1.weight 2.76232860
module.extra_1.conv1.bias module.extra_1.conv1.bias -0.00553248
module.extra_1.conv2.weight module.extra_1.conv2.weight -100.39555516
module.extra_1.conv2.bias module.extra_1.conv2.bias 0.00963779
module.extra_2.bn.weight module.extra_2.bn.weight -144.90659545
module.extra_2.bn.bias module.extra_2.bn.bias -0.43241561
module.extra_2.conv1.weight module.extra_2.conv1.weight -18.49401752
module.extra_2.conv1.bias module.extra_2.conv1.bias -0.03962684
module.extra_2.conv2.weight module.extra_2.conv2.weight -98.76576164
module.extra_2.conv2.bias module.extra_2.conv2.bias -0.07895776
module.extra_3.bn.weight module.extra_3.bn.weight -137.74657961
module.extra_3.bn.bias module.extra_3.bn.bias -1.83718258
module.extra_3.conv1.weight module.extra_3.conv1.weight -6.63687622
module.extra_3.conv1.bias module.extra_3.conv1.bias 0.16047683
module.extra_3.conv2.weight module.extra_3.conv2.weight -64.03853174
module.extra_3.conv2.bias module.extra_3.conv2.bias 0.37029462
module.extra_4.bn.weight module.extra_4.bn.weight -150.30557569
module.extra_4.bn.bias module.extra_4.bn.bias -0.88545457
module.extra_4.conv1.weight module.extra_4.conv1.weight 8.52840125
module.extra_4.conv1.bias module.extra_4.conv1.bias -0.16135700
module.extra_4.conv2.weight module.extra_4.conv2.weight 39.86314841
module.extra_4.conv2.bias module.extra_4.conv2.bias -0.30344061
module.extra_5.bn.weight module.extra_5.bn.weight -153.87934927
module.extra_5.bn.bias module.extra_5.bn.bias -0.57383157
module.extra_5.conv1.weight module.extra_5.conv1.weight -1.10513980
module.extra_5.conv1.bias module.extra_5.conv1.bias -0.10425282
module.extra_5.conv2.weight module.extra_5.conv2.weight 36.12376689
module.extra_5.conv2.bias module.extra_5.conv2.bias -0.45356037
module.extra_6.bn.weight module.extra_6.bn.weight -118.99042057
module.extra_6.bn.bias module.extra_6.bn.bias -1.05029858
module.extra_6.conv1.weight module.extra_6.conv1.weight 47.75907117
module.extra_6.conv1.bias module.extra_6.conv1.bias -0.30105668
module.extra_6.conv2.weight module.extra_6.conv2.weight 82.42883147
module.extra_6.conv2.bias module.extra_6.conv2.bias 0.24271000
module.extra_7.bn.weight module.extra_7.bn.weight -112.90572042
module.extra_7.bn.bias module.extra_7.bn.bias 2.30864563
module.extra_7.conv1.weight module.extra_7.conv1.weight 14.77395574
module.extra_7.conv1.bias module.extra_7.conv1.bias 0.08763358
module.extra_7.conv2.weight module.extra_7.conv2.weight 7.20600131
module.extra_7.conv2.bias module.extra_7.conv2.bias -0.28086568
module.d.0.weight module.d.0.weight -0.09255437
module.d.1.weight module.d.1.weight 8.70911378
module.d.1.bias module.d.1.bias 0.00000000
module.d.3.weight module.d.3.weight 0.05706205

As I expected, the difference between the two models’ parameters in the pre-trained section is 0, which is good. As a further check that the frozen part really is frozen, I forwarded a random input image through both models and compared the outputs of the frozen part. This check again confirmed that the frozen part’s parameters have not changed. But when I switched both networks to evaluation mode using model1.train(False) and model0.train(False), the outputs were different. I don’t know what the problem is or how to figure it out.
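A likely explanation for "equal in training mode, different in evaluation mode": in training mode BatchNorm normalizes with the statistics of the current batch, so the stored running statistics do not affect the output, while in evaluation mode it uses `running_mean`/`running_var`. A minimal sketch with two hypothetical BatchNorm layers that share their weights but have different running statistics:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn_a = nn.BatchNorm2d(4)
bn_b = nn.BatchNorm2d(4)

# Same affine parameters (defaults: weight=1, bias=0) but different running
# statistics, as if bn_b had seen additional training batches.
with torch.no_grad():
    bn_b.running_mean.fill_(0.5)
    bn_b.running_var.fill_(2.0)

x = torch.randn(8, 4, 5, 5)

# Training mode: each layer normalizes x with x's own batch statistics.
bn_a.train()
bn_b.train()
train_equal = torch.allclose(bn_a(x), bn_b(x))

# Evaluation mode: each layer normalizes with its stored running statistics.
bn_a.eval()
bn_b.eval()
eval_equal = torch.allclose(bn_a(x), bn_b(x))

print(train_equal, eval_equal)
```

Identical outputs in training mode therefore do not show that the running statistics match.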

One thing to note is that my models do not have any Dropout layers, but they have lots of BatchNorm layers, one after each Conv layer. I also compared the running mean and running variance of corresponding BN layers in the two models using the following code and found that they were equal:

model0_running_variance = []
model0_running_mean = []
for module in model0.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model0_running_mean.append(module.running_mean)
        model0_running_variance.append(module.running_var)

model1_running_variance = []
model1_running_mean = []
for module in model1.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model1_running_mean.append(module.running_mean)
        model1_running_variance.append(module.running_var)

print("running_mean difference")
for m1, m2 in zip(model0_running_mean, model1_running_mean):
    print("{0:.8f}".format((m1 - m2).sum()))

print("running_var difference")
for m1, m2 in zip(model0_running_variance, model1_running_variance):
    print("{0:.8f}".format((m1 - m2).sum()))

And the output:

running_mean difference
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
running_var difference
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000
0.00000000

Best.


@ptrblck, sorry man! Do you have any idea about my problem?

Let me understand the issue properly.

You have two models, each with the same pre-trained base model.
Comparing the pre-trained section yields exactly the same results.

Maybe a stupid question, but when you pass the random input, is it the same input for both networks?
Were the models in training mode when you passed the inputs?

After this test (which gives the same result for the random input), you switched both models to evaluation mode and the results differ, right?

Did you compare the running_mean and running_var after the test had already failed?

Hey, thanks for your response.

  1. Yes, it is the same for both networks.
  2. Yes, both are in training mode.
  3. Yes, the difference appears in evaluation mode.
  4. Yes.

:frowning: I don’t know what is going on. Please help.

Hmm, something is odd.
When you pushed the random input into your models in training mode, the BatchNorm layers should have updated their running_mean and running_var.
Since you used the same input, the updates should be the same in both “base” models, which you made sure with the last test.

The only way forward I see is for you to post the model definitions so we can debug them.
Could you provide a minimal working example (most likely just the base model will be sufficient)?

Unfortunately, it is an ongoing project and I don’t have permission to share parts of it. If possible, please tell me what you would like to see (intermediate outputs or anything else) and I will share it with you! :slight_smile:

Yeah, it should update them, even if they are excluded from what the optimizer can see. But it’s worth mentioning that the DataParallel wrapper doesn’t properly update running statistics. https://github.com/pytorch/pytorch/issues/1051
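To illustrate the first point: `requires_grad=False` only stops gradient computation, and buffers are never handed to the optimizer, so a forward pass in training mode still changes the running statistics. Only evaluation mode stops the update. A small sketch with a toy layer (not the original model):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False          # "frozen" in the optimizer sense

before = bn.running_mean.clone()     # clone: the buffer is updated in place
bn.train()
with torch.no_grad():                # no autograd and no optimizer involved
    bn(torch.randn(4, 3, 8, 8) + 1.0)
train_unchanged = torch.equal(before, bn.running_mean)

bn.eval()                            # evaluation mode uses, but never updates, the stats
frozen = bn.running_mean.clone()
with torch.no_grad():
    bn(torch.randn(4, 3, 8, 8) + 1.0)
eval_unchanged = torch.equal(frozen, bn.running_mean)

print(train_unchanged, eval_unchanged)
```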


Could you tell us in detail how you froze the pretrained part?

Yes, sure. Using the following functions, I set requires_grad to True or False depending on whether a parameter belongs to my pre-trained section.

def release_weight(model):
    for param in model.parameters():
        param.requires_grad = True
    return model


def freeze_weight(model):
    for param in model.stage1.parameters():
        param.requires_grad = False
    for param in model.stage4.parameters():
        param.requires_grad = False
    for param in model.stage5.parameters():
        param.requires_grad = False
    for param in model.stage5_1.parameters():
        param.requires_grad = False
    for param in model.parallel1.parameters():
        param.requires_grad = False
    for param in model.parallel2.parameters():
        param.requires_grad = False
    for param in model.stage7.parameters():
        param.requires_grad = False
    for param in model.final.parameters():
        param.requires_grad = False
    for param in model.d.parameters():
        param.requires_grad = False
    return model

When I instantiate the model, I call the two functions above:

model = Model(pretrained_path, 425)
model = release_weight(model)
model = freeze_weight(model)

Then I select the parameters that have requires_grad=True with the following commands and pass them to the optimizer:

parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.SGD(parameters, lr=baseLR, weight_decay=0.001, momentum=0.9)
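A sketch of the same selection written with the built-in `filter`, plus a check that BatchNorm buffers are not parameters at all, so neither `requires_grad` nor this filter has any effect on them. The three-layer model is a hypothetical stand-in for the real network:

```python
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-in: a "pre-trained" conv + BN block and a new conv layer.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8),   # frozen part
    nn.Conv2d(8, 4, 1),                      # trainable extra layer
)
for p in list(model[0].parameters()) + list(model[1].parameters()):
    p.requires_grad = False

trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.SGD(trainable, lr=1e-3, weight_decay=0.001, momentum=0.9)

# Only the extra layer's weight and bias reach the optimizer.
n_in_optimizer = sum(len(g["params"]) for g in optimizer.param_groups)

# Running statistics live in buffers, which parameters() never returns.
buffer_names = [name for name, _ in model.named_buffers()]
param_names = [name for name, _ in model.named_parameters()]
print(n_in_optimizer, buffer_names)
```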

I have the same problem. I’m not sure whether it is caused by DataParallel or something else.

@SimonW That’s good to know! Thanks for the link!

@mderakhshani Ok, I understand. Do you see the same error when you load both models from the checkpoints, set them to evaluation mode and pass a random sample through them?
Do you use any “special” layers, e.g. with conditions etc.?

@mfayyaz Are you able to share a small code snippet reproducing this issue?

Wow! It’s amazing!
Test your saving procedure and load your parameters again. I had a similar problem, and it was solved by checking my save and load procedure.
Also, your definition of the new layers could cause this problem.
I hope it gets resolved soon…
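One way to act on this advice is a round-trip test: save the `state_dict`, load it into a fresh instance, and compare the outputs in evaluation mode. A minimal sketch using an in-memory buffer; `make_model` is a hypothetical stand-in for the real model constructor:

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for the real model constructor.
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())


torch.manual_seed(0)
model = make_model()
model.train()
model(torch.randn(4, 3, 16, 16))          # give the BN buffers non-default values

buf = io.BytesIO()
torch.save(model.state_dict(), buf)        # state_dict includes the BN buffers
buf.seek(0)

restored = make_model()
restored.load_state_dict(torch.load(buf), strict=True)

model.eval()
restored.eval()
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)
```

If this check passes for the real model, the save/load path can be ruled out and the difference must come from what happened before saving.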


Let me give you all the combinations of networks I have tested. :smiley:

Rule 1: pre-trained net: a network without any training; it just loads the pre-trained weights and initializes the extra layers.

Rule 2: checkpoint: a network initialized with the pre-trained weights and trained on my dataset (no parameter updates in my pre-trained, frozen section).

  1. pre-trained and pre-trained in training mode: the outputs are equal.
  2. pre-trained and pre-trained in evaluation mode: the outputs are equal.
  3. pre-trained and checkpoint0 in training mode: the outputs are equal.
  4. pre-trained and checkpoint0 in evaluation mode: the outputs are not equal.
  5. checkpoint0 and checkpoint1 in training mode: the outputs are equal.
  6. checkpoint0 and checkpoint1 in evaluation mode: the outputs are equal.

I do not have any special layers in my net. All of them are regular layers such as Conv2d, MaxPool, BatchNorm and nothing else.

Hello @SimonW, @mfayyaz, @ptrblck, @Khalooei. I have prepared two scripts: one for training and the other for comparing the outputs. Here is my main file, which overfits a single sample starting from the resnet18 pre-trained weights.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models


class MyModel(nn.Module):
    """Pre-trained ResNet-18 features followed by a new linear layer."""
    def __init__(self, resnet18):
        super(MyModel, self).__init__()
        # all ResNet-18 children except the final fc layer
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        # 2048 = 512 channels x 2 x 2 for 256x256 inputs, assuming the older
        # torchvision ResNet with a fixed AvgPool2d(7); with adaptive pooling
        # the flattened size would be 512
        self.f1 = nn.Linear(2048, 4)

    def forward(self, input):
        out = self.features(input).detach()
        return self.f1(out.view(out.size(0), -1))


def release_weight(model):
    for param in model.parameters():
        param.requires_grad = True
    return model


def freeze_weight(model):
    for param in model.features.parameters():
        param.requires_grad = False
    return model


resnet18 = models.resnet18(pretrained=True)
model = MyModel(resnet18)
model = release_weight(model)
model = freeze_weight(model)

parameters = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.SGD(parameters, lr=1e-3, weight_decay=0.001, momentum=0.9)

inp = torch.randn(1, 3, 256, 256)
target = torch.randn(1, 4)

for i in range(2):
    for j in range(100):
        optimizer.zero_grad()
        out = model(inp)
        loss = ((out - target) ** 2).mean()
        loss.backward()
        optimizer.step()
        print(loss.item())
    torch.save(model.state_dict(), './simpleModel{}.pth'.format(i))

And here is the comparison code:

import torch
import torch.nn as nn
import torchvision.models as models


class MyModel(nn.Module):
    """Same model, but forward also returns the output of the frozen part."""
    def __init__(self, resnet18):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        # 2048 assumes the older torchvision ResNet (fixed AvgPool2d(7))
        # and 256x256 inputs; with adaptive pooling it would be 512
        self.f1 = nn.Linear(2048, 4)

    def forward(self, input):
        out1 = self.features(input)
        return out1, self.f1(out1.view(out1.size(0), -1))


ckpt0 = 'simpleModel0.pth'
ckpt1 = 'simpleModel1.pth'

resnet18 = models.resnet18(pretrained=True)
model0 = MyModel(resnet18)
model0.load_state_dict(torch.load(ckpt0), strict=True)
model0.train(False)

resnet18 = models.resnet18(pretrained=True)
model1 = MyModel(resnet18)
# model1.load_state_dict(torch.load(ckpt1))
model1.train(False)

model0_name = []
for name, param in model0.named_parameters():
    model0_name.append(name)

model1_name = []
for name, param in model1.named_parameters():
    model1_name.append(name)

# abs() so that positive and negative differences cannot cancel out
diff = []
for p1, p2 in zip(model0.parameters(), model1.parameters()):
    diff.append((p1.data - p2.data).abs().sum())

for n in zip(model0_name, model1_name, diff):
    print(n[0], n[1], "{0:.8f}".format(n[2]))

model0_running_variance = []
model0_running_mean = []
for module in model0.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model0_running_mean.append(module.running_mean)
        model0_running_variance.append(module.running_var)

model1_running_variance = []
model1_running_mean = []
for module in model1.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model1_running_mean.append(module.running_mean)
        model1_running_variance.append(module.running_var)

print("running_mean difference")
for m1, m2 in zip(model0_running_mean, model1_running_mean):
    print("{0:.8f}".format((m1 - m2).abs().sum()))

print("running_var difference")
for m1, m2 in zip(model0_running_variance, model1_running_variance):
    print("{0:.8f}".format((m1 - m2).abs().sum()))

inp = torch.randn(1, 3, 256, 256)
with torch.no_grad():  # inference only, no graph needed
    pretrained_head0, out0 = model0(inp)
    pretrained_head1, out1 = model1(inp)

diffpretrained_head = (pretrained_head1 - pretrained_head0).abs().sum().item()
diffout = (out1 - out0).abs().sum().item()

print(diffpretrained_head, diffout)

Again, these scripts confirm my claim. Could you please check them on your system?

Hi, I had the same problem a couple of months ago, and I thought something was wrong in my code!


I debugged your code a bit and it seems that the BatchNorm layers differ.
You can’t see it because of a typo in the code that collects running_mean and running_var:

model0_running_variance = []
model0_running_mean = []
for module in model0.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model0_running_mean.append(module.running_mean)
        model0_running_variance.append(module.running_var)

model1_running_variance = []
model1_running_mean = []
for module in model0.modules():
    if isinstance(module, nn.modules.BatchNorm2d):
        model1_running_mean.append(module.running_mean)
        model1_running_variance.append(module.running_var)

In both loops you iterate model0.
I will check, why they differ.


The problem is that your BN layers differ.
I used the following code to solve the problem (just override the train method of your model):

    def train(self, mode=True, freeze_bn=False, freeze_bn_affine=False):
        super(MyModel, self).train(mode)
        if freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    # Freeze running mean/var of BatchNorm2d
                    m.eval()
                    if freeze_bn_affine:
                        # Freeze weight/bias of BatchNorm2d
                        m.weight.requires_grad = False
                        m.bias.requires_grad = False
        # nn.Module.train returns self, so keep that behavior
        return self
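A self-contained sketch of how an override like the one above might be used. The small `MyModel` here is a hypothetical stand-in (one conv, one BN, one linear head), and `freeze_bn=True` has to be passed every time training mode is entered; `model.eval()` still works because it calls `train(False)`:

```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.f1 = nn.Linear(8, 4)

    def forward(self, x):
        out = self.features(x).mean(dim=(2, 3))   # global average pooling
        return self.f1(out)

    def train(self, mode=True, freeze_bn=False, freeze_bn_affine=False):
        super(MyModel, self).train(mode)
        if freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()                             # freeze running mean/var
                    if freeze_bn_affine:
                        m.weight.requires_grad = False   # freeze weight/bias
                        m.bias.requires_grad = False
        return self


model = MyModel()
model.train(freeze_bn=True)
before = model.features[1].running_mean.clone()
model(torch.randn(4, 3, 16, 16))
still_frozen = torch.equal(before, model.features[1].running_mean)
head_training = model.f1.training
print(still_frozen, head_training)
```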

Thanks. I have edited my reply! Sorry.

I’ve checked the training part and the running_* stats are being updated.
You have to set the BatchNorm layers to evaluation mode.

Add this to your training:

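import torch.nn as nn

def freeze_bn(m):
    # Put BatchNorm layers in eval mode so running_mean/running_var stop updating
    if isinstance(m, nn.BatchNorm2d):
        m.eval()

model.apply(freeze_bn)  # re-apply after every model.train() call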
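Because `model.train()` switches every submodule back to training mode, the function has to be re-applied after each such call, for example at the start of every epoch. A minimal sketch of that ordering with a hypothetical toy model and random data:

```python
import torch
import torch.nn as nn
import torch.optim as optim


def freeze_bn(m):
    if isinstance(m, nn.BatchNorm2d):
        m.eval()


torch.manual_seed(0)
model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 4),
)
# Only the final linear layer is optimized in this toy example.
optimizer = optim.SGD(model[5].parameters(), lr=1e-3, momentum=0.9)
bn = model[1]
start_mean = bn.running_mean.clone()

for epoch in range(2):
    model.train()            # puts BN back into training mode...
    model.apply(freeze_bn)   # ...so freeze it again every time
    for _ in range(5):
        optimizer.zero_grad()
        loss = model(torch.randn(4, 3, 16, 16)).pow(2).mean()
        loss.backward()
        optimizer.step()

unchanged = torch.equal(start_mean, bn.running_mean)
print(unchanged, bn.training)
```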

EDIT: Too late, @mfayyaz was faster! :wink: Good catch!


Wooow. Thanks @mfayyaz. It was solved by your snippet.