Can't reset parameters with torch.nn.utils.prune

Hi, I am trying to implement the lottery ticket hypothesis using torch.nn.utils.prune, but I am having trouble resetting the model parameters. From what I understand, pruning moves each module's original weight parameter to a new parameter, weight_orig, and turns weight into a plain tensor computed from weight_orig and the mask. When I try to reinitialize the model parameters with model.apply(reinit_weights), the weight_orig parameter in each module does not change. Below is some code that demonstrates the problem. Does someone have an explanation or a potential workaround?

Thanks so much in advance.

def reinit_weights(m):
    """Re-initialize a Conv2d/Linear module in place; other modules are ignored.

    Meant to be used as ``model.apply(reinit_weights)``.

    ``reset_parameters()`` alone is not enough on a module pruned with
    ``torch.nn.utils.prune``. There the trainable tensor is ``<name>_orig``,
    and ``<name>`` is a plain tensor derived from ``<name>_orig * <name>_mask``.
    So ``reset_parameters()`` only overwrites the derived copy. This function
    copies the fresh values into ``<name>_orig`` and re-applies the mask, so
    the mask itself is preserved.

    Args:
        m: any ``nn.Module``; only ``nn.Conv2d`` and ``nn.Linear`` are touched.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    m.reset_parameters()
    for name in ("weight", "bias"):
        orig = getattr(m, f"{name}_orig", None)
        mask = getattr(m, f"{name}_mask", None)
        if orig is None or mask is None:
            # Not pruned: reset_parameters() already reset the real parameter.
            continue
        with torch.no_grad():
            # getattr(m, name) now holds the freshly initialized values.
            orig.copy_(getattr(m, name))
        # Keep the derived tensor consistent with the new orig and the mask
        # (the same assignment the pruning hook performs before each forward).
        setattr(m, name, orig * mask)

def test_model_change(prev_iter_dict, model):
    for name, param in model.named_parameters():
        prev_param = prev_iter_dict[name]
        assert not torch.allclose(prev_param,param), 'model not updating'

model = LeNet5()
model.train().to(device)

# Attach an all-ones mask (amount=0.0 prunes nothing) to every Conv2d/Linear
# weight. L1Unstructured.apply is a classmethod, so no pruner instance is
# needed. Afterwards each such module's trainable tensor is `weight_orig`,
# and `weight` is derived from `weight_orig * weight_mask`.
for m in model.modules():
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        L1Unstructured.apply(m, name='weight', amount=0.0)

# Build the optimizer after moving the model to the device and after pruning,
# so it clearly tracks the final parameters (weight_orig, bias, ...).
optimizer = optim.Adam(model.parameters(), lr=lr)

train(model, 1, train_loader, device, optimizer)  # train for one epoch to test
full_cap = copy.deepcopy(model.state_dict())
# reinit_weights calls reset_parameters(); on a pruned module that only
# refreshes the derived `weight` tensor unless it also resets `weight_orig`.
model.apply(reinit_weights)
# Checks every named parameter (weight_orig, bias, ...) against the snapshot;
# it passes only if weight_orig was re-initialized as well.
test_model_change(full_cap, model)

Personally, I simply save the weights I’d like to re-initialize to and then load the .pt file when it’s time to re-initialize.

Thank you for your response, but this doesn’t work for me. It passes my test_model_change function, so the weights do get reinitialized, but when I try to retrain my model, the loss does not decrease. Interestingly, when I reset the model to the original weights from the first initialization, I don’t face this issue; it only happens when I set the weights to new random values.

def reset_weights_rand(model, path='rand_dict.pkl'):
    """Overwrite every parameter of ``model`` with values from a pickled dict.

    Args:
        model: module whose ``named_parameters()`` are overwritten in place.
            For a pruned module these include ``weight_orig``, not ``weight``.
        path: pickle file holding a ``{parameter name: tensor}`` mapping.
            Defaults to ``'rand_dict.pkl'``.

    Raises:
        KeyError: if a parameter name is missing from the pickled mapping.

    NOTE: ``pickle.load`` can execute arbitrary code; only load files you
    wrote yourself.
    """
    with open(path, 'rb') as f:
        rand_state_dict = pickle.load(f)
    with torch.no_grad():
        for name, param in model.named_parameters():
            # copy_ writes into the existing storage, so the Parameter object
            # held by the optimizer, its device, and its dtype are preserved
            # (assigning .data would silently swap all three).
            param.copy_(rand_state_dict[name])
# Build a randomly re-initialized copy of the initial weights and save it.
rand_state_dict = copy.deepcopy(initial_state_dict)

for key, value in rand_state_dict.items():
    # A pruned model's state_dict() also holds the binary `*_mask` buffers;
    # they are masks, not weights, so never overwrite them with random values.
    if key.endswith('_mask'):
        continue
    if 'bias' in key:
        # NOTE(review): init.normal_ defaults to mean=0, std=1, far larger than
        # the default Conv2d/Linear bias init; confirm this is intended.
        init.normal_(value)
    else:
        init.xavier_normal_(value)

with open("rand_dict.pkl", "wb") as output_file:
    pickle.dump(rand_state_dict, output_file)

train(model, 1, train_loader, device, optimizer)  # train for one epoch to test

full_cap = copy.deepcopy(model.state_dict())

reset_weights_rand(model)

test_model_change(full_cap, model)  # passes now

# Retrain for 5 more epochs. A fresh Adam instance starts with no moment
# estimates carried over from the first training run.
optimizer2 = torch.optim.Adam(model.parameters(), lr=lr)

train(model, 5, train_loader, device, optimizer2)


As you can see, the loss does not decrease even though no pruning has been done.

Hi,

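I actually managed to hack together a workaround with the tools in torch.nn.utils.prune. The trick is to save the dictionary of masks, remove the pruning reparametrization with remove, and then add the masks back with CustomFromMask. Removing the pruning turns weight back into an ordinary parameter, so reinitialization reaches it again.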

from torch.nn.utils.prune import l1_unstructured, remove, CustomFromMask,is_pruned


def add_masks(model, masks, name='weight'):
    """Re-attach saved pruning masks to the Conv2d/Linear modules of ``model``.

    Args:
        model: module tree to prune in place.
        masks: mapping from state-dict-style keys to mask tensors, e.g. the
            ``*_mask`` entries of an earlier ``model.state_dict()``. Keys have
            the form ``"<module path>.<name>_mask"``, or just ``"<name>_mask"``
            for the root module.
        name: parameter to mask on each module. Defaults to ``'weight'``.

    Conv2d/Linear modules without an entry in ``masks`` are left unpruned.
    Each mask is moved to the device of the tensor it masks before it is
    applied with ``CustomFromMask``.
    """
    for module_name, module in model.named_modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        # named_modules() yields '' for the root, whose state_dict() keys
        # carry no leading dot.
        key = f"{module_name}.{name}_mask" if module_name else f"{name}_mask"
        if key not in masks:
            continue
        mask = masks[key].to(getattr(module, name).device)
        # CustomFromMask.apply is a classmethod; no pruner instance is needed.
        CustomFromMask.apply(module, name, mask)

def merge_masks(model, name='weight'):
    """Make pruning of ``name`` permanent on every Conv2d/Linear of ``model``.

    ``remove`` folds ``<name>_orig * <name>_mask`` back into an ordinary
    ``<name>`` parameter and drops the mask and the pruning hook. After that,
    ``reset_parameters()`` reaches the real parameter again. Save the masks
    first (e.g. the ``*_mask`` entries of ``model.state_dict()``) if you want
    to re-apply them later.

    Args:
        model: module tree modified in place.
        name: pruned parameter to merge. Defaults to ``'weight'``.

    Modules whose ``name`` tensor is not pruned are skipped; calling
    ``remove`` on them would raise ``ValueError``.
    """
    for module in model.modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        # is_pruned() is True if *any* tensor of the module is pruned (e.g.
        # only the bias), so also check that this parameter was reparametrized.
        if not is_pruned(module) or not hasattr(module, f"{name}_orig"):
            continue
        remove(module, name=name)

With this fix, I can reinitialize my network with no issues. Hopefully this can help someone if they were dealing with my issue.

1 Like