Reset model weights

I would like to know if there is a way to reset the weights of a PyTorch model.

Here is my code:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=5)

        self.pool = nn.MaxPool2d(5, stride=3)
        self.pool2 = nn.MaxPool2d(3, stride=1)
        self.activation = nn.ReLU()
        self.fc1 = nn.Linear(4096, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x):
        x = self.activation( self.pool(self.conv1(x)) )
        x = self.activation( self.pool(self.conv2(x)) )
        x = self.activation( self.pool(self.conv3(x)) )
        x = self.activation( self.pool2(self.conv4(x)) )
        x = x.view(-1, 64*8*8)

        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return x

I just want to average a couple of runs of the model in order to evaluate it.

Any idea how I can do that?

You could save the state_dict and load it for resetting the model. Have a look at the Serialization Semantics to see how to do it.
Would this work for you or do you want to re-initialize it to random weights?
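
A minimal sketch of the save-and-restore approach, assuming the snapshot is kept in memory with copy.deepcopy instead of being written to disk. The small nn.Linear model is only a stand-in for Net, and the "training" step is simulated:

```python
import copy

import torch
import torch.nn as nn

# Stand-in model; any nn.Module is handled the same way.
model = nn.Linear(4, 2)

# Snapshot the initial weights. deepcopy is needed because state_dict()
# returns references to the live tensors, not copies.
initial_state = copy.deepcopy(model.state_dict())

# Simulate training by changing the weights in place.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Restore the snapshot before starting the next run.
model.load_state_dict(initial_state)
```

Note that this restores the same starting weights for every run; it does not draw new random ones.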

Thank you for the hint.

However, is there a way to do the random re-initialization?

Sure! You just have to define your init function:

def weights_init(m):
    # Re-initialize only the convolutional layers
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)

And call it on the model with:

model.apply(weights_init)

If you want to have the same random weights for each initialization, you would need to set the seed before calling this method with:

torch.manual_seed(your_seed)
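
As a rough illustration of the seeding point, the sketch below re-initializes a small stand-in model twice with the same seed and checks that the weights match. The model and the seed value 0 are arbitrary choices for the example, and the in-place xavier_uniform_ initializer is used:

```python
import torch
import torch.nn as nn


def weights_init(m):
    # Re-initialize only convolutional layers.
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight)


# Small stand-in model for the example.
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3),
)

torch.manual_seed(0)  # arbitrary example seed
model.apply(weights_init)
first = model[0].weight.detach().clone()

torch.manual_seed(0)  # same seed again
model.apply(weights_init)
second = model[0].weight.detach().clone()

# True when both initializations produced identical weights.
same = torch.equal(first, second)
```

Without the second manual_seed call, the two initializations would generally differ, which is what you want when averaging over independent runs.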

This probably works too. It is not a robust solution and won’t work for anything except core torch.nn layers, but it works:

for layer in model.children():
   if hasattr(layer, 'reset_parameters'):
       layer.reset_parameters()

credit: Reinitializing the weights after each cross validation fold
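
One caveat with the loop above: model.children() only yields the direct children, so layers inside nested containers are skipped. A possible variant, sketched here with a hypothetical helper name reset_nested and a toy nested model, iterates over model.modules() instead:

```python
import torch
import torch.nn as nn

# Toy model with one nested container.
model = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),  # nested block
    nn.Linear(4, 1),
)


def reset_nested(model: nn.Module) -> int:
    """Call reset_parameters() on every submodule that defines it; return the count."""
    count = 0
    for layer in model.modules():  # recursive, unlike children()
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
            count += 1
    return count


# Number of direct children that could be reset (the nested Linear is missed).
direct = sum(hasattr(c, "reset_parameters") for c in model.children())
# Number of layers reset when walking the whole module tree.
n_reset = reset_nested(model)
```

Here the children() loop would only reach the outer nn.Linear, while the modules() walk reaches both linear layers.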

Could I double-check whether the code below is a robust solution? It is working, but it does seem useful to corroborate. I am looking for a solution that clears all weights between iterations of a hyper-parameter search. I am running the individual models as subprocesses.

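import os

chk_dir = '/root/.cache/torch/hub/checkpoints/'

# Delete every cached checkpoint file in the torch hub checkpoint directory
if os.path.isdir(chk_dir):
    for chkpnt in os.scandir(chk_dir):
        if chkpnt.is_file():
            print(f'removing {chkpnt.path}')
            os.remove(chkpnt.path)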

Why doesn’t this work for you:

@unnir

Here is the code with an example that runs:

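import torch
import torch.nn as nn
from torch import Tensor


def lp_norm(mdl: nn.Module, p: int = 2) -> Tensor:
    # Sum of the per-parameter Lp norms over all parameters of the model
    lp_norms = [w.norm(p) for name, w in mdl.named_parameters()]
    return sum(lp_norms)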

def reset_all_weights(model: nn.Module) -> None:
    """
    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def weight_reset(m: nn.Module):
        # - if the current module has a callable reset_parameters, call it on m
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=weight_reset)


def reset_all_linear_layer_weights(model: nn.Module) -> None:
    """
    Sets the weights of all linear layers to 1.0, recursively.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if isinstance(m, nn.Linear):
            m.weight.fill_(1.0)

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)


def reset_all_weights_with_specific_layer_type(model: nn.Module, modules_type2reset) -> None:
    """
    Resets all weights recursively for layers of the given type (modules_type2reset).

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m):
        if isinstance(m, modules_type2reset):
            # if type(m) == torch.nn.BatchNorm2d:
            #     m.weight.fill_(1.0)
            m.reset_parameters()

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(init_weights)


# -- tests

def reset_params_test():
    import torchvision.models as models
    # lp_norm is the helper defined above

    resnet18 = models.resnet18(pretrained=True)
    resnet18_random = models.resnet18(pretrained=False)

    print(f'{lp_norm(resnet18)=}')
    print(f'{lp_norm(resnet18_random)=}')
    print(f'{lp_norm(resnet18)=}')
    reset_all_weights(resnet18)
    print(f'{lp_norm(resnet18)=}')


if __name__ == '__main__':
    reset_params_test()
    print('Done! \a\n')

output:

lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18_random)=tensor(668.3687, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(476.0836, grad_fn=<AddBackward0>)
Done!

I am assuming this works because I calculated the norm twice for the pre-trained net and it was the same both times before calling reset, and different afterwards.

I must admit I was unhappy that it wasn’t closer to the norm of the random net, but I think this is good enough.
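
One possible explanation for the gap, offered as an assumption and not verified in this thread: a model can apply its own initialization inside __init__, while reset_parameters() only restores each layer's default scheme. The toy CustomInitNet and the make_model factory below are hypothetical and only illustrate the difference between resetting layers and re-instantiating the model:

```python
import torch
import torch.nn as nn


class CustomInitNet(nn.Module):
    """Hypothetical model that overrides the default layer init in __init__."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)
        nn.init.constant_(self.fc.weight, 0.5)  # custom, model-specific init


def make_model() -> nn.Module:
    # Factory so that every run can start from a freshly built model.
    return CustomInitNet()


model = make_model()
custom_before = bool((model.fc.weight == 0.5).all())

# reset_parameters() falls back to nn.Linear's default init, not the custom one.
model.fc.reset_parameters()
custom_after_reset = bool((model.fc.weight == 0.5).all())

# Re-instantiating runs __init__ again, so the custom init is reproduced.
fresh = make_model()
custom_in_fresh = bool((fresh.fc.weight == 0.5).all())
```

If matching the original initialization matters, building a new model per run through such a factory may be the simplest option.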

related: python 3.x - Reset parameters of a neural network in pytorch - Stack Overflow

@ptrblck @Brando_Miranda I was trying to do something with regards to resetting the weights and then applying a new tensor as my weights for a particular layer. I just had one doubt regarding the above discussion - does the reset_parameters() also clear all the associated memory that the layer occupies?

reset_parameters usually changes the parameters in place, as seen e.g. in nn.Linear, so the memory usage shouldn’t change.
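
A small check of this for nn.Linear, as a sketch: the data pointer of the weight tensor is compared before and after the reset, and a new tensor (new_weight, an arbitrary example value) is then copied into the existing parameter in place:

```python
import torch
import torch.nn as nn

layer = nn.Linear(16, 16)
ptr_before = layer.weight.data_ptr()
old_values = layer.weight.detach().clone()

# Re-draw the weights; for nn.Linear this writes into the existing tensor.
layer.reset_parameters()

same_storage = layer.weight.data_ptr() == ptr_before
values_changed = not torch.equal(layer.weight.detach(), old_values)

# Assign custom weights without allocating a new parameter.
new_weight = torch.ones(16, 16)  # arbitrary example values
with torch.no_grad():
    layer.weight.copy_(new_weight)

still_same_storage = layer.weight.data_ptr() == ptr_before
```

Copying with copy_ under torch.no_grad() keeps the same parameter object, so an optimizer that already holds a reference to it keeps working.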