Same input, same weights, different output

Good morning,

I am having an issue fine-tuning a model: the feature vector changes when it should not!

I will show two cases; the first one works as expected, the second one does not:

Considerations for both cases (for debugging):

  • Dataloader outputs images in the same exact order every epoch
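  • All layers are frozen except the last linear layer (requires_grad = False for every parameter except those of the last layer)

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class myLoss():
	def __init__(self):
		self.crossEntropy = nn.CrossEntropyLoss()

	def __call__(self, outputs, targets):
		# outputs is the (logits, feat) tuple returned by the model
		return self.crossEntropy(outputs[0], targets)

def set_parameter_requires_grad(model):
	# freeze every parameter of the given module
	for param in model.parameters():
		param.requires_grad = False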
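As a quick sanity check of the freezing helper, here is a minimal sketch that lists which parameters are still trainable. The toy `nn.Sequential` model is an assumption for illustration, not the model from this post.

```python
import torch.nn as nn

def set_parameter_requires_grad(model):
    # freeze every parameter of the given module
    for param in model.parameters():
        param.requires_grad = False

# hypothetical toy model: a conv "backbone" followed by a linear head
model = nn.Sequential(
    nn.Conv2d(3, 6, 5),
    nn.Flatten(),
    nn.Linear(6 * 28 * 28, 2),  # sized for 32x32 inputs; not run here
)
set_parameter_requires_grad(model[0])  # freeze only the conv layer

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['2.weight', '2.bias']
```

Note that requires_grad only controls gradient computation. Buffers such as the BatchNorm running statistics are not parameters, so they are not affected by it.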

First case (expected result: the feature vector does not change across epochs for the same input):

class myModel(nn.Module):
	def __init__(self, args):
		super(myModel, self).__init__()
		
		self.backbone = nn.Sequential(
			nn.Conv2d(3, 6, 5),
			nn.ReLU(),
			nn.MaxPool2d(2, 2),
			nn.Conv2d(6, 16, 5),
			nn.ReLU(),
			nn.MaxPool2d(2, 2)
			)
		set_parameter_requires_grad(self.backbone)
		self.fc = nn.Linear(13456,2)
		
	def forward(self, x):
		batchSize = x.shape[0]
		x = self.backbone(x)
		feat = x.view(batchSize, -1)
		x = F.relu(self.fc(feat))
		return x, feat
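The `in_features=13456` of `self.fc` only matches one input resolution. Assuming 3x128x128 images (the post does not state the size), this sketch checks the flattened size and shows that this backbone, which has no BatchNorm or Dropout layers, returns the same features on repeated passes in train mode.

```python
import torch
import torch.nn as nn

# same layers as the first case's backbone
backbone = nn.Sequential(
    nn.Conv2d(3, 6, 5),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 5),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
)
backbone.train()

# assumed input size: 128 -> conv 124 -> pool 62 -> conv 58 -> pool 29
x = torch.randn(2, 3, 128, 128)
feat1 = torch.flatten(backbone(x), 1)
feat2 = torch.flatten(backbone(x), 1)

print(feat1.shape)                   # torch.Size([2, 13456]), i.e. 16 * 29 * 29
print(torch.allclose(feat1, feat2))  # True: no layer here depends on train/eval mode
```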

When printing “feat”, the output does not change across epochs: for the same input, feat is identical.

Second case (unexpected result: the feature vector changes across epochs for the same input):

class myModel(nn.Module):
	def __init__(self, args):
		super(myModel, self).__init__()
		model = models.resnet18(pretrained=args.preTrained)
		modules=list(model.children())[:-1]
		self.backbone = nn.Sequential(*modules)
		set_parameter_requires_grad(self.backbone)
		self.fc = nn.Linear(512, 2)
		
	def forward(self, x):
		feat = self.backbone(x).squeeze()
		x = self.fc(feat)
		return x, feat
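A side note on `self.backbone(x).squeeze()`, unrelated to the changing features: the ResNet backbone (all children except `fc`) ends with an adaptive average pool that yields shape `[N, 512, 1, 1]`, and `squeeze()` also drops the batch dimension when N is 1. This sketch uses a random tensor as a stand-in for the backbone output and contrasts it with `torch.flatten(x, 1)`.

```python
import torch

# stand-in for the backbone output: ResNet's avgpool gives [N, 512, 1, 1]
for n in (4, 1):
    pooled = torch.randn(n, 512, 1, 1)
    print(n, tuple(pooled.squeeze().shape), tuple(torch.flatten(pooled, 1).shape))
# 4 (4, 512) (4, 512)
# 1 (512,) (1, 512)
```

With `torch.flatten(x, 1)` the feature tensor keeps its batch dimension regardless of the batch size.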

I am instantiating the model as:

featureExtractor = backbone(self.args)
net = myModel(self.args, featureExtractor)
output, feat = net(input)
print(feat)

In this case, when printing “feat”, the output does change across epochs: for the same input, feat varies…

I have built a function that checks whether the weights (before the last fc) have changed. It indicates that the weights have not changed… however, feat still changes…

if epoch == 0:
	# keep a copy of the network as it was in the first epoch
	self.first_net = copy.deepcopy(net)
else:
	# compare every parameter outside the fc layer with that copy
	for (name, p1), p2 in zip(self.first_net.named_parameters(), net.parameters()):
		if 'fc' not in name and not torch.equal(p1, p2):
			print('NETWORK HAS CHANGED THE WEIGHTS in layer {}'.format(name))
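The check above compares `parameters()` only. BatchNorm running statistics (`running_mean`, `running_var`, `num_batches_tracked`) are buffers, not parameters, so a comparison over `state_dict()`, which contains both, would show them changing. A minimal sketch: `changed_entries` is a hypothetical helper, and the small conv + BatchNorm model is a stand-in for the ResNet backbone.

```python
import copy
import torch
import torch.nn as nn

def changed_entries(net_a, net_b, skip="fc"):
    """Return the state_dict keys (parameters AND buffers) that differ."""
    sd_a, sd_b = net_a.state_dict(), net_b.state_dict()
    return [k for k in sd_a if skip not in k and not torch.equal(sd_a[k], sd_b[k])]

# hypothetical stand-in for the frozen backbone: conv + BatchNorm
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
for p in net.parameters():
    p.requires_grad = False  # frozen, as in the post

first_net = copy.deepcopy(net)
net.train()
with torch.no_grad():
    net(torch.randn(4, 3, 16, 16))  # a single forward pass in train mode

print(changed_entries(first_net, net))
# ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```

The weights are untouched, yet the buffers moved, which is consistent with a parameter-only check reporting no change.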

Am I missing something?

Thank you in advance!

I get the same output for repeated forward passes using resnet18:

import torch
import torch.nn as nn
from torchvision import models

class myModel(nn.Module):
	def __init__(self):
		super(myModel, self).__init__()
		model = models.resnet18(pretrained=True)
		modules=list(model.children())[:-1]
		self.backbone = nn.Sequential(*modules)
		self.fc = nn.Linear(512, 2)
		
	def forward(self, x):
		feat = self.backbone(x).squeeze()
		x = self.fc(feat)
		return x, feat

model = myModel()

x = torch.randn(1, 3, 224, 224)
out, feat = model(x)
for _ in range(10):
    out_tmp, feat_tmp = model(x)
    print(torch.allclose(out, out_tmp))
    print(torch.allclose(feat, feat_tmp))

That being said, since the model is in training mode, the running statistics of the batchnorm layers (running_mean and running_var) will be updated in every forward pass (and dropout will be applied if the model defines it).
To get the same outputs, it’s recommended to use model.eval() to disable dropout and use the running estimates of all batchnorm layers.
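For the fine-tuning case in the question, one common pattern (an assumption on my side, not something stated in this thread) is to keep the frozen backbone in eval mode even while the rest of the model trains, by overriding `train()`. A minimal sketch with a small conv + BatchNorm stand-in for the ResNet backbone; `FrozenBackboneModel` is a hypothetical name.

```python
import torch
import torch.nn as nn

class FrozenBackboneModel(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-in backbone with BatchNorm, frozen as in the post
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
        )
        for p in self.backbone.parameters():
            p.requires_grad = False
        self.fc = nn.Linear(8, 2)

    def train(self, mode=True):
        # switch all submodules as usual, then force the frozen backbone to eval
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        feat = torch.flatten(self.backbone(x), 1)
        return self.fc(feat), feat

net = FrozenBackboneModel()
net.train()
x = torch.randn(4, 3, 16, 16)
_, feat1 = net(x)
_, feat2 = net(x)
print(net.backbone.training, net.fc.training)    # False True
print(int(net.backbone[1].num_batches_tracked))  # 0: running stats were not updated
```

With this override, calling net.train() at the start of each epoch does not re-enable batch statistics in the backbone, so feat for a given input should stay fixed while only fc is trained.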


Hi, @ptrblck, thank you for your great answer. But may I ask why BN layers change the model behavior? To be on the same page, I am referring to the following case:

model.train()
output1 = model(input)
output2 = model(input)

Will output1 and output2 be different?

output1 and output2 might be different if dropout is applied.
The internal running stats of all batchnorm layers will be updated in both forward passes.
This means that the outputs after calling model.eval() would differ:

model.train()
output1 = model(input)

model.eval()
output1_eval = model(input)

model.train()
output2 = model(input)

model.eval()
output2_eval = model(input)

Here output1_eval and output2_eval would be different even though dropout is deactivated, since the batchnorm stats were updated by the second train-mode pass.
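The sequence above can be reproduced with a small stand-in model containing only conv and BatchNorm layers (the model is an assumption, not from the thread). The two train-mode outputs match, while the two eval-mode outputs differ because the train-mode pass in between moved the running statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# hypothetical model: BatchNorm, no dropout
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
inp = torch.randn(4, 3, 16, 16)

model.train()
output1 = model(inp)       # updates running_mean / running_var

model.eval()
output1_eval = model(inp)  # uses the running estimates after one update

model.train()
output2 = model(inp)       # updates the running estimates again

model.eval()
output2_eval = model(inp)  # uses the running estimates after two updates

print(torch.allclose(output1, output2))            # True: train mode uses batch statistics
print(torch.allclose(output1_eval, output2_eval))  # False: running stats changed in between
```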

@ptrblck Thank you for your patience and kindness. May I ask: in my case, if only BN layers are involved (no dropout), are the outputs expected to be the same?

Yes. If no dropout (or other custom layers that change their behavior using the self.training flag) is used, the outputs should be equal up to floating point precision, i.e. the max. absolute error should be at most at the floating point precision limit.
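To illustrate this reply, here is a sketch comparing two train-mode passes on the same input for a BatchNorm-only model and for the same architecture with an added `nn.Dropout` (both toy models are assumptions). The max. absolute difference should be zero or at floating point noise level for the first model and clearly non-zero for the second.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 16, 16)

models_to_check = {
    "bn_only": nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()),
    "with_dropout": nn.Sequential(
        nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(), nn.Dropout(0.5)
    ),
}

max_abs_diff = {}
for name, model in models_to_check.items():
    model.train()
    with torch.no_grad():
        # two train-mode passes on the same input
        max_abs_diff[name] = (model(x) - model(x)).abs().max().item()
    print(name, max_abs_diff[name])
```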
