I am finding prune.global_unstructured unnecessarily difficult to use. I have a torch.nn.Module that is a Sequential model of layers and other Sequential objects. I'd like to experiment with global pruning, so a natural starting point is to try pruning all parameters. To do so I need to pass a list of tuples of the form (module, name), where name is the name of a parameter of that module.
So I do:
parameters_to_prune = []
for name, val in model.named_parameters():
    parameters_to_prune.append((model, name))
Yet the implementation of prune.global_unstructured does not look the parameter up with Module.get_parameter, which would play nicely with the dotted keys from named_parameters, but calls getattr(model, name), which fails: the actual parameter lives under a chain of submodules, e.g. model[0].weight for the '0.weight' parameter.
Is there a reason for this design? Is there a better way than writing a recursive function to find all leaf submodules of the model?
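One possible way to avoid hand-written recursion, sketched under the assumption that every parameter should be pruned (including biases and BatchNorm parameters, which may not be desirable): model.modules() already walks the whole tree, and named_parameters(recurse=False) yields only the parameters a module owns directly, so each tuple pairs a parameter with its owning module and local name. The helper name collect_prunable and the toy model are hypothetical.

```python
import torch.nn as nn
from torch.nn.utils import prune


def collect_prunable(model):
    """Return (owning_module, local_name) pairs for every parameter in model."""
    pairs = []
    for module in model.modules():  # visits the model and all nested submodules
        # recurse=False: only parameters held directly by this module
        for name, _ in module.named_parameters(recurse=False):
            pairs.append((module, name))
    return pairs


# Hypothetical toy model standing in for a nested Sequential
toy = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Sequential(nn.Linear(8, 4)))

parameters_to_prune = collect_prunable(toy)
prune.global_unstructured(
    parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.5
)
```

Because the names are local ('weight', 'bias'), getattr(module, name) inside prune succeeds. To restrict pruning to weights only, the inner loop could be filtered on the name.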
Here is the exact way I call prune, and the output:
import torch.nn as nn
from torch.nn.utils import prune

# Residual and args are defined elsewhere in my script (not shown).
def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10):
    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim)
            )),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        ) for i in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )

model = ConvMixer(args.hdim, args.depth)

parameters_to_prune = []
for name, param in model.named_parameters():
    # Pairs the top-level model with a dotted name such as '0.weight'
    parameters_to_prune.append((model, name))

prune.global_unstructured(
    parameters=parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.05
)
""" Output
Traceback (most recent call last):
  File "python3.9/site-packages/torch/nn/utils/prune.py", line 1088, in global_unstructured
    [
  File "python3.9/site-packages/torch/nn/utils/prune.py", line 1089, in <listcomp>
    getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
  File "python3.9/site-packages/torch/nn/modules/module.py", line 1177, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Sequential' object has no attribute '0.weight'
"""