ndronen
(Nicholas Dronen)
February 20, 2025, 2:21pm
1
Should one be able to enable (or disable) pruning of specific neurons by assigning 0 (or 1) to the corresponding elements of `weight_mask`? I've tried the script below with torch 2.6 and 2.5.1, and the last assertion fails.
import torch
import torch.nn.utils.prune as prune
# Reproducer: applying a pruning method (ln_structured) updates `weight`
# right away, but editing `weight_mask` by hand after identity pruning
# does not change `weight`.
# 4x4 float weight whose every row is [1, 2, 3, 4] (all 16 entries nonzero).
weight = torch.arange(1, 5).repeat(1, 4).view(4, 4).float()
linear1 = torch.nn.Linear(4, 4, bias=False)
linear2 = torch.nn.Linear(4, 4, bias=False)
# Give both layers identical, deterministic weights.
linear1.weight.data = weight.clone()
linear2.weight.data = weight.clone()
# Structured L2-norm (n=2) pruning of half (amount=0.5) of the rows (dim=0),
# so 2 of the 4 rows -> 8 of 16 entries are masked out.
prune.ln_structured(linear1, name="weight", n=2, amount=0.5, dim=0)
assert linear1.weight_mask.count_nonzero() == 8
assert linear1.weight.count_nonzero() == 8
# Identity pruning attaches a mask that prunes nothing; the mask for the
# first two rows is then zeroed in place by hand.
prune.identity(linear2, "weight")
linear2.weight_mask[:2].fill_(0)
assert linear2.weight_mask.count_nonzero() == 8 # pass
# Fails: the hand-edited mask has not been applied to `weight` yet, so
# `weight` still has all 16 entries nonzero.
assert linear2.weight.count_nonzero() == 8 # fail
ndronen
(Nicholas Dronen)
March 15, 2025, 10:52pm
2
ndronen:
import torch
import torch.nn.utils.prune as prune
# (Quoted reproducer) A pruning method updates `weight` immediately, but a
# manual edit of `weight_mask` is not reflected in `weight` yet.
# 4x4 float weight whose every row is [1, 2, 3, 4].
weight = torch.arange(1, 5).repeat(1, 4).view(4, 4).float()
linear1 = torch.nn.Linear(4, 4, bias=False)
linear2 = torch.nn.Linear(4, 4, bias=False)
linear1.weight.data = weight.clone()
linear2.weight.data = weight.clone()
# Prune 2 of the 4 rows (dim=0) by L2 norm.
prune.ln_structured(linear1, name="weight", n=2, amount=0.5, dim=0)
assert linear1.weight_mask.count_nonzero() == 8
assert linear1.weight.count_nonzero() == 8
# Identity pruning, then zero the mask for the first two rows by hand.
prune.identity(linear2, "weight")
linear2.weight_mask[:2].fill_(0)
assert linear2.weight_mask.count_nonzero() == 8 # pass
# Fails until the module is called (see the explanation that follows).
assert linear2.weight.count_nonzero() == 8 # fail
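I figured this out. A manual update to `weight_mask` is only reflected in the `weight` tensor after the module is called. The call invokes the forward pre-hook registered by the pruning framework, and that hook recomputes `weight` as `weight_orig * weight_mask`. Until then, `weight` keeps the value computed when pruning was applied.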
import torch
import torch.nn.utils.prune as prune

# Build the 4x4 weight directly: each of the 4 rows is [1, 2, 3, 4] (float32),
# so every one of the 16 entries starts out nonzero.
weight = torch.arange(1, 5, dtype=torch.float32).repeat(4, 1)

linear2 = torch.nn.Linear(4, 4, bias=False)
linear2.weight.data = weight.clone()

# Attach a mask that prunes nothing, then switch off the first two output
# rows by zeroing their mask entries in place.
prune.identity(linear2, "weight")
linear2.weight_mask[:2].fill_(0)

# The hand-edited mask only reaches `linear2.weight` when the pruning
# forward pre-hook runs, i.e. when the module is called; a single dummy
# forward pass is enough to apply it.
linear2(torch.randn(1, 4))

assert linear2.weight_mask.count_nonzero() == 8  # pass
assert linear2.weight.count_nonzero() == 8  # pass