How to reset the weights for the entire network, using the original PyTorch weight initialization
You could create a weight_reset function similar to weight_init and reset the weights:
import torch.nn as nn

def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

model = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.ReLU(),
    nn.Linear(20, 3)
)
model.apply(weight_reset)
Or alternatively just create a new instance of your model.
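For example, a minimal sketch of the second option (make_model is a made-up factory, just to show that re-instantiating re-runs the default initialization):

import torch.nn as nn

def make_model():
    # made-up factory: calling it again re-runs the default PyTorch init
    return nn.Sequential(
        nn.Conv2d(3, 6, 3, 1, 1),
        nn.ReLU(),
        nn.Linear(20, 3)
    )

model = make_model()
# ... train ...
model = make_model()  # fresh instance -> freshly initialized parameters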
thank you! Solved my issue
Is it possible to reset only part of the weights of the model? For instance, the weights of a specific layer, or even some random weights of one layer?
You can use the weight attribute of your Conv or Linear layer to get these parameters…
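For example, a minimal sketch of both ideas (the model, the layer indices, and the 10% mask are made up just for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 3),
)

# 1) reset one specific layer only
model[2].reset_parameters()

# 2) re-initialize only a random subset of one layer's weights
layer = model[0]
with torch.no_grad():
    mask = torch.rand_like(layer.weight) < 0.1          # pick ~10% of the entries at random
    new_weight = torch.empty_like(layer.weight)
    nn.init.kaiming_uniform_(new_weight, a=5 ** 0.5)    # the default nn.Linear weight init scheme
    layer.weight[mask] = new_weight[mask]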
This probably works too:
This is not a robust solution and won't work for anything except core torch.nn layers, but this works:
for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()
credit: Reinitializing the weights after each cross validation fold
Solution by @Brando_Miranda did not work for me. I had to add an extra loop:
for layers in model.children():
    for layer in layers:
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
@ptrblck I was looking for a more general way to reset all model parameters in case you don’t know the specific layers of the model. Would that do the trick?
def weight_reset(m):
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

model.apply(weight_reset)
Your approach looks valid and I’m unsure if there is a better way. You could use e.g. if isinstance(m, nn.Module), but this would also return True for custom modules, which don’t necessarily implement the reset_parameters() method.
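For example, a small sketch of why the getattr/callable check is the safer test (MyCustomModule is a made-up module without its own reset_parameters()):

import torch.nn as nn

class MyCustomModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
    # no reset_parameters() defined here

m = MyCustomModule()

# isinstance(m, nn.Module) is True, but calling m.reset_parameters() would raise AttributeError
print(isinstance(m, nn.Module))                           # True
print(callable(getattr(m, "reset_parameters", None)))     # False -> safely skipped
print(callable(getattr(m.fc, "reset_parameters", None)))  # True  -> the child Linear is still reset via .apply()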
Is your response to Alon?
btw related: python 3.x - Reset parameters of a neural network in pytorch - Stack Overflow
@Alon shouldn’t your code be recursive?
Here is the code with an example that runs:
import torch
import torch.nn as nn
from torch import Tensor


def lp_norm(mdl: nn.Module, p: int = 2) -> Tensor:
    lp_norms = [w.norm(p) for name, w in mdl.named_parameters()]
    return sum(lp_norms)
def reset_all_weights(model: nn.Module) -> None:
    """
    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def weight_reset(m: nn.Module):
        # - check if the current module has reset_parameters & if it's callable, call it on m
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    # Applies fn recursively to every submodule, see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=weight_reset)
def reset_all_linear_layer_weights(model: nn.Module) -> None:
    """
    Fills the weights of all linear layers with 1.0, recursively.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)

    # Applies fn recursively to every submodule, see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)
def reset_all_weights_with_specific_layer_type(model: nn.Module, modules_type2reset) -> None:
    """
    Resets all weights recursively for modules of the given type (e.g. nn.Linear).

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if type(m) == modules_type2reset:
            # if type(m) == torch.nn.BatchNorm2d:
            #     m.weight.fill_(1.0)
            m.reset_parameters()

    # Applies fn recursively to every submodule, see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)
# -- tests

def reset_params_test():
    import torchvision.models as models
    # from uutils.torch_uu import lp_norm  # not needed here, lp_norm is defined above

    resnet18 = models.resnet18(pretrained=True)
    resnet18_random = models.resnet18(pretrained=False)

    print(f'{lp_norm(resnet18)=}')
    print(f'{lp_norm(resnet18_random)=}')
    print(f'{lp_norm(resnet18)=}')
    reset_all_weights(resnet18)
    print(f'{lp_norm(resnet18)=}')


if __name__ == '__main__':
    reset_params_test()
    print('Done! \a\n')
output:
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18_random)=tensor(668.3687, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(476.0836, grad_fn=<AddBackward0>)
Done!
I am assuming this works because I calculated the norm twice for the pre-trained net and it was the same both times before calling reset.
I must admit I was hoping the reset norm would be closer to that of the random net, but I think this is good enough.
related: python 3.x - Reset parameters of a neural network in pytorch - Stack Overflow
I find it better to just
del loss, model, optimizer
torch.cuda.empty_cache()
And then recreate the model and optimizer as you did in the first place.
Note that not deleting the loss resulted in CUDA memory errors for me, as it is connected to the model’s gradient graph and thus keeps the whole thing in memory.
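A minimal sketch of that pattern (build_model and the optimizer settings are placeholders for however you constructed them the first time):

import torch
import torch.nn as nn

def build_model():
    # placeholder: recreate the model exactly as you did the first time
    return nn.Linear(10, 2)

model = build_model().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss = model(torch.randn(4, 10, device='cuda')).sum()

# free the old objects; the loss still references the computation graph
del loss, model, optimizer
torch.cuda.empty_cache()

# then recreate the model and optimizer the same way as before
model = build_model().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)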
I think that model.apply() is already recursive.
See: Module — PyTorch 1.11.0 documentation
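A quick sketch of the difference between .children() and .apply() on a nested model (the toy model below is just for illustration):

import torch.nn as nn

model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),  # a nested container
    nn.Linear(4, 2),
)

# .children() only yields the immediate submodules
print([type(m).__name__ for m in model.children()])
# ['Sequential', 'Linear']

# .apply() visits every submodule recursively (children first, then the module itself)
visited = []
model.apply(lambda m: visited.append(type(m).__name__))
print(visited)
# ['Linear', 'ReLU', 'Sequential', 'Linear', 'Sequential']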
Why is the _reset_parameters method private (i.e., has an underscore in its name) for torch.nn.MultiheadAttention?
I’m unsure why some modules define it as a private method while others, such as nn.Linear, do not. @zhangguanheng66 might know more, as nn.MultiheadAttention was merged in this PR.
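In the meantime, a small sketch of how the private name interacts with the getattr-based reset above (note that _reset_parameters() is private, so it may change between PyTorch versions):

import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# the public name does not exist, so the getattr-based reset above skips this module
print(callable(getattr(mha, "reset_parameters", None)))   # False
print(callable(getattr(mha, "_reset_parameters", None)))  # True

with torch.no_grad():
    mha._reset_parameters()  # works, but relies on a private method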