How to reset all parameters in a network

How to reset the weights for the entire network, using the original PyTorch weight initialization?

2 Likes

You could create a weight_reset function similar to weight_init and reset the weights:

import torch.nn as nn

def weight_reset(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

model = nn.Sequential(
    nn.Conv2d(3, 6, 3, 1, 1),
    nn.ReLU(),
    nn.Linear(20, 3)
)

model.apply(weight_reset)
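
Continuing from the snippet above, a quick sanity check (a sketch) that the reset actually changed the weights:

import torch

before = model[0].weight.clone()
model.apply(weight_reset)
print(torch.equal(before, model[0].weight))  # False with overwhelming probability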

Or, alternatively, just create a new instance of your model.
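
One variant of that second option (a sketch; build_model is a hypothetical factory) keeps the existing model object by copying in the state dict of a fresh instance. Since load_state_dict copies values in-place, this is handy if an optimizer already holds references to the model's parameters:

import torch.nn as nn

def build_model() -> nn.Module:
    # hypothetical factory: construct the model the same way as originally
    return nn.Sequential(
        nn.Conv2d(3, 6, 3, 1, 1),
        nn.ReLU(),
        nn.Linear(20, 3),
    )

model = build_model()
# ... training would happen here ...
model.load_state_dict(build_model().state_dict())  # overwrite with a fresh init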

10 Likes

Thank you! Solved my issue.

Is it possible to reset only part of the weights of the model? For instance, the weights of a specific layer, or even some random weights of one layer?

2 Likes

You can use the weight attribute of your Conv or Linear layer to access these parameters directly and reset just the subset you need.
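
For example, a minimal sketch that reinitializes only a slice of one layer's weight tensor (the a=math.sqrt(5) value mirrors nn.Linear's default kaiming_uniform_ initialization):

import math
import torch
import torch.nn as nn

layer = nn.Linear(10, 10)
with torch.no_grad():
    # reinitialize only the first five output rows of the weight matrix
    nn.init.kaiming_uniform_(layer.weight[:5], a=math.sqrt(5))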

This probably works too:

This is not a robust solution and won't work for anything except core torch.nn layers, but this works:

for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()

credit: Reinitializing the weights after each cross validation fold
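
Note that model.children() only yields direct submodules. Iterating model.modules() instead walks the module tree recursively (a sketch):

for layer in model.modules():
    # modules() yields the model itself and every nested submodule
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()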

3 Likes

The solution by @Brando_Miranda did not work for me. I had to add an extra loop:

for layers in model.children():
    for layer in layers:
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()

1 Like

@ptrblck I was looking for a more general way to reset all model parameters in case you don’t know the specific layers of the model. Would that do the trick?

def weight_reset(m):
    reset_parameters = getattr(m, "reset_parameters", None)
    if callable(reset_parameters):
        m.reset_parameters()

model.apply(weight_reset)

3 Likes

Your approach looks valid and I’m unsure if there is a better way.
You could use e.g. if isinstance(m, nn.Module), but this would also return True for custom modules, which don’t necessarily implement the reset_parameters() method.
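
For instance (a sketch), a custom module without its own reset_parameters() passes the isinstance check but fails the callable check, while model.apply() would still reach its core submodules:

import torch.nn as nn

class MyBlock(nn.Module):
    # custom module that does not define reset_parameters() itself
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

blk = MyBlock()
print(isinstance(blk, nn.Module))                        # True
print(callable(getattr(blk, "reset_parameters", None)))  # False
# apply() would still visit blk.fc, which does implement reset_parameters()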

3 Likes

Is your response to @Alon?


btw related: python 3.x - Reset parameters of a neural network in pytorch - Stack Overflow

@Alon, shouldn’t your code be recursive?

Here is the code with an example that runs:

import torch
import torch.nn as nn
from torch import Tensor

def lp_norm(mdl: nn.Module, p: int = 2) -> Tensor:
    # sum of the p-norms of all parameters; a cheap fingerprint of the weights
    lp_norms = [w.norm(p) for w in mdl.parameters()]
    return sum(lp_norms)

def reset_all_weights(model: nn.Module) -> None:
    """
    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def weight_reset(m: nn.Module):
        # - check if the current module has reset_parameters and, if it's callable, call it on m
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=weight_reset)


def reset_all_linear_layer_weights(model: nn.Module) -> None:
    """
    Fills the weights of all linear layers with 1.0, recursively.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)


def reset_all_weights_with_specific_layer_type(model: nn.Module, modules_type2reset) -> None:
    """
    Resets all weights recursively for modules of the given type.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if type(m) == modules_type2reset:
            m.reset_parameters()

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)


# -- tests

def reset_params_test():
    import torchvision.models as models

    resnet18 = models.resnet18(pretrained=True)
    resnet18_random = models.resnet18(pretrained=False)

    print(f'{lp_norm(resnet18)=}')
    print(f'{lp_norm(resnet18_random)=}')
    print(f'{lp_norm(resnet18)=}')
    reset_all_weights(resnet18)
    print(f'{lp_norm(resnet18)=}')


if __name__ == '__main__':
    reset_params_test()
    print('Done! \a\n')

output:

lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18_random)=tensor(668.3687, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(476.0836, grad_fn=<AddBackward0>)
Done!

I am assuming this works because I calculated the norm twice for the pre-trained net and it was the same both times before calling reset.

I must admit I was unhappy that the post-reset norm wasn’t closer to the norm of the random net, but I think this is good enough.

related: python 3.x - Reset parameters of a neural network in pytorch - Stack Overflow

I find it better to just

del loss, model, optimizer
torch.cuda.empty_cache()

And then recreate the model and optimizer the same way you did in the first place.

Note that not deleting the loss resulted in CUDA memory errors, as the loss is connected to the model’s gradient graph and thus keeps the whole thing in memory.
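
A minimal sketch of that pattern (make_model_and_optimizer is a hypothetical factory; rebuild yours however you built it originally):

import torch
import torch.nn as nn

def make_model_and_optimizer(device: str = "cpu"):
    # hypothetical factory: construct the model exactly as you did initially
    model = nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return model, optimizer

model, optimizer = make_model_and_optimizer()
loss = model(torch.randn(4, 10)).sum()  # stand-in for a real training loss

# drop the references; `loss` holds the autograd graph and keeps the model alive
del loss, model, optimizer
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# recreate from scratch: fresh initialization and fresh optimizer state
model, optimizer = make_model_and_optimizer()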

I think that model.apply() is already recursive.
See: Module — PyTorch 1.11.0 documentation
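
A quick check (a sketch) that apply() visits every nested submodule, children before parents:

import torch.nn as nn

model = nn.Sequential(nn.Sequential(nn.Linear(2, 2)), nn.Linear(2, 2))
model.apply(lambda m: print(type(m).__name__))
# prints: Linear, Sequential, Linear, Sequential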

Why is the _reset_parameters method private (i.e., why does it have an underscore in its name) for torch.nn.MultiheadAttention?

I’m unsure why some modules define it as a private method while others, such as nn.Linear, do not. @zhangguanheng66 might know more, as nn.MultiheadAttention was merged in this PR.