How to reset variables' values in nn.Modules?

Hi, I want to implement cross-validation efficiently. Therefore, I want to build one model (neural network) once and then randomly reset its variables for each fold of the cross-validation.

Is there a way to reset the variables' values so that the model behaves just as if it were freshly created?

Some example code:

import numpy as np
import torch

class Toy(torch.nn.Module):
	def __init__(self):
		super(Toy, self).__init__()
		self.fc = torch.nn.Linear(3, 4)

	def forward(self, x):
		return self.fc(x)

toy = Toy()
cross_validation_acc = []
for fold_num in range(5):
	train(toy)                               # placeholder training routine
	cross_validation_acc.append(test(toy))   # placeholder evaluation routine
	reset_variables(toy)                     # <-- how to implement this?
print(np.mean(cross_validation_acc))

Thank you very much!

You could simply write an initialization function, which could look like this:

def reset_parameters(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

and then call toy.reset_parameters() at the beginning of each fold.

Note: The example has been taken from the torchvision GitHub repo and does not contain an initialization for torch.nn.Linear. You would have to adapt it for your use case.
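For the Toy model above, a minimal sketch of such an adaptation could look like the following (the nn.Linear branch is an addition, not part of the torchvision example; it reuses nn.Linear's built-in reset_parameters(), which re-runs the layer's default initialization):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super(Toy, self).__init__()
        self.fc = nn.Linear(3, 4)

    def forward(self, x):
        return self.fc(x)

    def reset_parameters(self):
        # Re-run the default initialization for every submodule
        # that defines its own reset_parameters() (nn.Linear does).
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.reset_parameters()

toy = Toy()
for fold_num in range(5):
    toy.reset_parameters()  # fresh random weights for each fold
    # train(toy), test(toy), ...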

Great. I will also look up the default initialization methods for the other modules. Thank you very much for your help!

Great! But unless you use plain SGD without momentum as your optimizer, you would also have to reset your optimizer's state_dict for comparable results: stateful optimizers (e.g. SGD with momentum, or Adam) carry running statistics across steps, which would otherwise leak from one fold into the next.

So the idea is to first record the optimizer's state_dict() and then use load_state_dict() to reset it?

Yes, this should work.
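For reference, a minimal sketch of that reset (torch.optim.Adam stands in for any stateful optimizer; copy.deepcopy is used because state_dict() returns references into the optimizer's live state rather than a snapshot):

import copy

import torch

toy = Toy()
optimizer = torch.optim.Adam(toy.parameters(), lr=1e-3)

# Record the fresh (empty) optimizer state once, before any training step.
initial_state = copy.deepcopy(optimizer.state_dict())

for fold_num in range(5):
    toy.reset_parameters()                    # fresh random weights
    optimizer.load_state_dict(initial_state)  # fresh optimizer state
    # train(toy), test(toy), ... as in the original loop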

This is really good.