Pruning by manipulating weight_mask

Should one be able to prune (by assigning 0) or un-prune (by assigning 1) specific weights by writing directly to the corresponding elements of weight_mask? I’ve tried this script with torch 2.6 and 2.5.1, and the manual mask edit does not seem to take effect.

import torch
import torch.nn.utils.prune as prune

# 4x4 weight where every row is [1, 2, 3, 4] (arange -> (1, 16) -> (4, 4)).
weight = torch.arange(1, 5).repeat(1, 4).view(4, 4).float()

linear1 = torch.nn.Linear(4, 4, bias=False)
linear2 = torch.nn.Linear(4, 4, bias=False)

linear1.weight.data = weight.clone()
linear2.weight.data = weight.clone()

# Framework-driven pruning: zero out half of the rows (dim=0) by Ln norm.
# Here both the mask AND the materialized `weight` reflect the pruning
# immediately, because ln_structured applies the mask when it is installed.
prune.ln_structured(linear1, name="weight", n=2, amount=0.5, dim=0)
assert linear1.weight_mask.count_nonzero() == 8  # 2 of 4 rows zeroed
assert linear1.weight.count_nonzero() == 8

# Manual pruning: identity() installs an all-ones weight_mask, then we zero
# the first two rows by hand. The mask changes, but `weight` is NOT
# recomputed here — hence the second assert fails.
prune.identity(linear2, "weight")
linear2.weight_mask[:2].fill_(0)
assert linear2.weight_mask.count_nonzero() == 8 # pass
assert linear2.weight.count_nonzero() == 8      # fail

I figured this out. A manual update to weight_mask is only reflected in the weight tensor after the module is called: calling the module invokes the forward pre-hook registered by the pruning framework, which recomputes weight as weight_orig * weight_mask.

# Working version: a manual weight_mask edit propagates to `weight` only after
# the module runs forward, because the pruning framework's forward pre-hook
# recomputes weight = weight_orig * weight_mask on every call.
import torch
import torch.nn.utils.prune as prune

# Every row of the 4x4 weight is [1, 2, 3, 4].
base_weight = torch.arange(1, 5).float().repeat(4, 1)

linear2 = torch.nn.Linear(4, 4, bias=False)
linear2.weight.data = base_weight.clone()

# identity() re-parametrizes the layer with weight_orig + an all-ones weight_mask.
prune.identity(linear2, "weight")

# Manually prune the first two output rows by zeroing the mask directly.
linear2.weight_mask[:2].fill_(0)

# One forward pass fires the pre-hook, materializing the masked weight.
linear2(torch.randn(1, 4))

assert linear2.weight_mask.count_nonzero() == 8  # pass
assert linear2.weight.count_nonzero() == 8       # pass