Reset model weights

I would like to know if there is a way to reset the weights of a PyTorch model.

Here is my code:

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=5)

        self.pool = nn.MaxPool2d(5, stride=3)
        self.pool2 = nn.MaxPool2d(3, stride=1)
        self.activation = nn.ReLU()
        self.fc1 = nn.Linear(4096, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.activation(self.pool(self.conv1(x)))
        x = self.activation(self.pool(self.conv2(x)))
        x = self.activation(self.pool(self.conv3(x)))
        x = self.activation(self.pool2(self.conv4(x)))
        x = x.view(-1, 64*8*8)  # 64 channels * 8 * 8 = 4096 features, e.g. for 1x480x480 inputs

        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x

I just want to average a couple of runs of the model in order to evaluate it.

Any idea how I can do that?


You could save the state_dict and load it again to reset the model. Have a look at the Serialization Semantics docs to see how to do it.
Would this work for you or do you want to re-initialize it to random weights?
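A minimal sketch of the save/restore approach, assuming a small stand-in model (any nn.Module behaves the same way). One detail worth knowing: state_dict() returns references to the live tensors, so the snapshot is taken with copy.deepcopy; otherwise in-place updates during training would also change the "saved" values. You could equally torch.save() the copied dict to disk.

```python
import copy

import torch
import torch.nn as nn

# Stand-in model (assumption: replace with your own Net()).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

# Deep copy, because state_dict() shares storage with the model's parameters.
initial_state = copy.deepcopy(model.state_dict())

# ... one training run would go here; simulated by an in-place change ...
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Restore the exact initial weights before starting the next run.
model.load_state_dict(initial_state)
```

Every run then starts from identical weights, which isolates run-to-run variation coming from data order, dropout, etc.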


Thank you for the hint.

However, is there a way to re-initialize the weights randomly?


Sure! You just have to define your init function:

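def weights_init(m):
    if isinstance(m, nn.Conv2d):
        # xavier_uniform (no trailing underscore) is deprecated; the in-place
        # variant works directly on the Parameter, no .data needed
        torch.nn.init.xavier_uniform_(m.weight)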

And call it on the model with:

model.apply(weights_init)

If you want to have the same random weights for each initialization, you would need to set the seed before calling this method with:

torch.manual_seed(your_seed)
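A self-contained sketch of the seeded re-initialization. Unlike the snippet above, it also handles nn.Linear layers and zeroes the biases; both are my additions, not part of the original answer. The small Sequential model is a stand-in (assumption) for the Net defined earlier. Re-seeding with the same value reproduces the same weights, while a different seed gives different ones.

```python
import torch
import torch.nn as nn


def weights_init(m: nn.Module) -> None:
    # Re-draw conv and linear weights with Xavier-uniform and zero the biases.
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


# Stand-in model (assumption): Conv2d on 8x8 inputs -> 4 x 6 x 6 features.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3), nn.ReLU(), nn.Flatten(), nn.Linear(4 * 6 * 6, 1)
)

torch.manual_seed(0)
model.apply(weights_init)  # apply() visits every submodule recursively
first = [p.detach().clone() for p in model.parameters()]

torch.manual_seed(0)       # same seed -> identical weights
model.apply(weights_init)
second = [p.detach().clone() for p in model.parameters()]

torch.manual_seed(1)       # different seed -> different weights
model.apply(weights_init)
third = [p.detach().clone() for p in model.parameters()]
```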

This probably works too:


This is not a robust solution and won’t work for anything except core torch.nn layers, but this works:

# note: children() only visits the direct submodules, not nested ones
for layer in model.children():
    if hasattr(layer, 'reset_parameters'):
        layer.reset_parameters()

credit: Reinitializing the weights after each cross validation fold
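A small demonstration of one limitation of the children() loop, assuming a model with a nested container: the nested layer is skipped because the inner nn.Sequential has no reset_parameters() of its own. Iterating over model.modules() (which is also what model.apply(fn) walks) reaches every layer in the tree.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),  # nested container
)

nested = model[1][0]
before = nested.weight.detach().clone()

# Direct children only: the nested Linear is never reached.
for layer in model.children():
    if hasattr(layer, "reset_parameters"):
        layer.reset_parameters()
after_children = nested.weight.detach().clone()

# modules() walks the whole module tree, so the nested Linear is reset too.
for layer in model.modules():
    if hasattr(layer, "reset_parameters"):
        layer.reset_parameters()
after_modules = nested.weight.detach().clone()
```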


Could I double-check whether the solution below is robust? It works, but it seems useful to corroborate. I am looking for a way to clear all weights between iterations of a hyper-parameter search, and I am running each model as a separate sub-process.

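import os

# Deletes pretrained checkpoint files cached by torch.hub on disk;
# it does not touch the weights of a model already in memory.
chk_dir = '/root/.cache/torch/hub/checkpoints/'

if os.path.isdir(chk_dir):
    for chkpnt in os.scandir(chk_dir):
        if chkpnt.is_file():
            print(f'removing {chkpnt.path}')
            os.remove(chkpnt.path)  # safer than shelling out to `rm`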
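A hedged variant of the cleanup above that avoids hard-coding /root/.cache: torch.hub.get_dir() returns the hub cache directory (honouring $TORCH_HOME or torch.hub.set_dir()), and downloaded weights live in its checkpoints subfolder. The function name and the dry_run flag are hypothetical. Keep in mind that deleting these files only forces pretrained weights to be re-downloaded; a model constructed from scratch in a fresh subprocess already starts from newly drawn random weights.

```python
import os

import torch


def clear_hub_checkpoints(dry_run: bool = True) -> list:
    # Hypothetical helper: list (and optionally delete) cached hub checkpoints.
    chk_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
    removed = []
    if os.path.isdir(chk_dir):
        for entry in os.scandir(chk_dir):
            if entry.is_file():
                removed.append(entry.path)
                if not dry_run:
                    os.remove(entry.path)
    return removed


# Dry run: only reports what would be deleted.
print(clear_hub_checkpoints(dry_run=True))
```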

Why doesn’t this work for you?

@unnir

Here is the code with an example that runs:

import torch
import torch.nn as nn
from torch import Tensor


def lp_norm(mdl: nn.Module, p: int = 2) -> Tensor:
    lp_norms = [w.norm(p) for name, w in mdl.named_parameters()]
    return sum(lp_norms)

def reset_all_weights(model: nn.Module) -> None:
    """
    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def weight_reset(m: nn.Module):
        # - check if the current module has reset_parameters and, if it is callable, call it on m
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=weight_reset)


def reset_all_linear_layer_weights(model: nn.Module) -> None:
    """
    Sets the weights of all nn.Linear layers (recursively) to 1.0.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if type(m) == nn.Linear:
            m.weight.fill_(1.0)

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)


def reset_all_weights_with_specific_layer_type(model: nn.Module, modules_type2reset) -> None:
    """
    Resets all weights recursively for layers whose type is modules_type2reset.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if type(m) == modules_type2reset:
            # if type(m) == torch.nn.BatchNorm2d:
            #     m.weight.fill_(1.0)
            m.reset_parameters()

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)


# -- tests

def reset_params_test():
    import torchvision.models as models

    # weights= API needs torchvision >= 0.13 (older versions: pretrained=True/False)
    resnet18 = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    resnet18_random = models.resnet18(weights=None)

    print(f'{lp_norm(resnet18)=}')
    print(f'{lp_norm(resnet18_random)=}')
    print(f'{lp_norm(resnet18)=}')
    reset_all_weights(resnet18)
    print(f'{lp_norm(resnet18)=}')


if __name__ == '__main__':
    reset_params_test()
    print('Done! \a\n')

Output:

lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18_random)=tensor(668.3687, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(476.0836, grad_fn=<AddBackward0>)
Done!

I am assuming this works because I computed the norm of the pre-trained net twice before calling reset and got the same value both times, and then got a different value after the reset.

I must admit I was unhappy that it wasn’t closer to the norm of the random net, but I think this is good enough.
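One plausible explanation for the gap: torchvision’s ResNet constructor re-initializes its conv layers with kaiming_normal_(mode="fan_out", nonlinearity="relu"), whereas nn.Conv2d.reset_parameters() uses kaiming_uniform_ with a=sqrt(5), whose standard deviation is smaller (1/sqrt(3·fan_in) versus sqrt(2/fan_out)). The sketch below compares the two on a single 3x3 conv; the function name is hypothetical and the scheme is my reading of torchvision’s resnet.py, so check it against your torchvision version.

```python
import math

import torch
import torch.nn as nn


@torch.no_grad()
def torchvision_resnet_style_init(m: nn.Module) -> None:
    # Hypothetical helper mirroring torchvision ResNet's constructor init
    # (assumption: default zero_init_residual=False). Other layers untouched.
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


conv = nn.Conv2d(64, 64, kernel_size=3, bias=False)
fan = 64 * 3 * 3  # fan_in == fan_out for this layer

conv.reset_parameters()  # default scheme: kaiming_uniform_(a=sqrt(5))
default_std = conv.weight.std().item()

torchvision_resnet_style_init(conv)
resnet_std = conv.weight.std().item()

print(f"default std {default_std:.4f} (theory {1 / math.sqrt(3 * fan):.4f})")
print(f"resnet  std {resnet_std:.4f} (theory {math.sqrt(2 / fan):.4f})")
```

To get closer to a freshly constructed network, you could call reset_all_weights(resnet18) and then resnet18.apply(torchvision_resnet_style_init).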

Related: python 3.x - Reset parameters of a neural network in pytorch - Stack Overflow

@ptrblck @Brando_Miranda I am trying to reset the weights and then assign a new tensor as the weights of a particular layer. I have one question about the discussion above: does reset_parameters() also free the memory that the layer occupies?

reset_parameters usually changes the parameters in place (see e.g. nn.Linear), so the memory usage shouldn’t change.
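A quick check of the in-place behaviour for nn.Linear: after reset_parameters() the weight is the same Parameter object backed by the same storage, only its values differ. For the follow-up use case of loading your own tensor afterwards, copying into the existing Parameter under torch.no_grad() (rather than assigning a new one) keeps reusing that storage; the tensor sizes here are arbitrary.

```python
import torch
import torch.nn as nn

layer = nn.Linear(1024, 1024)

weight_obj = layer.weight              # the Parameter object itself
ptr_before = layer.weight.data_ptr()   # address of its underlying storage
values_before = layer.weight.detach().clone()

layer.reset_parameters()  # nn.Linear re-draws weight and bias in place

same_object = layer.weight is weight_obj
same_storage = layer.weight.data_ptr() == ptr_before
values_changed = not torch.equal(layer.weight, values_before)
print(same_object, same_storage, values_changed)  # → True True True

# Loading a custom tensor afterwards: copy into the existing storage.
new_weight = torch.randn(1024, 1024)
with torch.no_grad():
    layer.weight.copy_(new_weight)
```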

How is reset_parameters different from re-instantiating the model?

.reset_parameters() will reset the parameters in place, so the parameters remain the same objects and only their values change.
This would allow you to use the same optimizer etc. in case you’ve already passed the parameters to it.
If you are creating a new module, you would of course also reset the parameters, but these parameters are new objects which you might need to pass to an optimizer again (depending on your actual use case).
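A sketch tying this back to the original goal of averaging several runs, using synthetic data and a small model (both assumptions). It re-implements the idea of reset_all_weights from earlier in the thread with model.modules() so the block is self-contained. One caveat: stateful optimizers such as Adam or SGD with momentum keep per-parameter state keyed by those same Parameter objects, so reusing an optimizer after an in-place reset would carry that state over; the sketch therefore builds a fresh optimizer for each run, and afterwards checks that an optimizer still references the model’s parameters after a reset.

```python
import torch
import torch.nn as nn


def reset_all_weights(model: nn.Module) -> None:
    # Call reset_parameters() in place on every submodule that defines it.
    for m in model.modules():
        if callable(getattr(m, "reset_parameters", None)):
            m.reset_parameters()


torch.manual_seed(0)
x = torch.randn(64, 4)
y = x.sum(dim=1, keepdim=True)  # synthetic regression target (assumption)

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
loss_fn = nn.MSELoss()

final_losses = []
for run in range(3):
    reset_all_weights(model)  # same Parameter objects, new random values
    # Fresh optimizer so Adam's moment estimates do not leak between runs.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    for _ in range(50):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    final_losses.append(loss.item())

mean_loss = sum(final_losses) / len(final_losses)
print(f"final losses: {final_losses}, mean: {mean_loss:.4f}")

# After an in-place reset the optimizer still holds the model's parameters.
reset_all_weights(model)
same_params = all(
    p is q for p, q in zip(optimizer.param_groups[0]["params"], model.parameters())
)
```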
