Prune.global_unstructured usability

I am finding prune.global_unstructured unnecessarily difficult to use. I have a torch.nn.Module that is an nn.Sequential model made of layers and other nn.Sequential objects. I'd like to experiment with global pruning, so the natural starting point is to try pruning all parameters. To do so, I need to pass a list of (module, name) tuples, where name is the name of a parameter of that module.

So I do:

parameters_to_prune = []

for name, val in model.named_parameters():
    parameters_to_prune.append((model, name))

Yet prune.global_unstructured does not look parameters up with Module.get_parameter, which would work with the dotted names from named_parameters. Instead it calls getattr(model, name), which fails because the actual parameter lives one or more submodules down, e.g. model[0].weight for the '0.weight' parameter.
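To make the mismatch concrete, here is a minimal sketch on a small nested nn.Sequential (a stand-in, not the ConvMixer below): named_parameters() yields dotted paths, getattr on the top-level module cannot resolve them, while Module.get_parameter (assuming PyTorch >= 1.9, where it was added) walks the path.

```python
import torch.nn as nn

# A small nested model, similar in shape to the one in this post
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=2, stride=2),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=1)),
)

# named_parameters() reports paths relative to `model`
names = [name for name, _ in model.named_parameters()]
print(names)  # ['0.weight', '0.bias', '1.0.weight', '1.0.bias']

# getattr only looks up a single attribute name, so a dotted path fails
try:
    getattr(model, "0.weight")
    found_by_getattr = True
except AttributeError:
    found_by_getattr = False

# Module.get_parameter follows the dotted path down to model[0].weight
w = model.get_parameter("0.weight")
```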

Is there a reason for this design? Is there a better way than writing a fancy recursive function to find all leaf submodules of the model?
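On the second question, hand-written recursion may not be needed, since Module.modules() already walks the whole module tree. A minimal sketch, using a hypothetical helper name collect_prunable and assuming you want every parameter paired with the module that directly owns it:

```python
import torch.nn as nn

def collect_prunable(model: nn.Module):
    """Hypothetical helper: (owner_module, param_name) pairs for all parameters.

    model.modules() visits every module in the tree, including nested
    nn.Sequential containers. recurse=False lists only parameters registered
    directly on each module, so every parameter appears exactly once,
    paired with the module that owns it.
    """
    pairs = []
    for module in model.modules():
        for param_name, _ in module.named_parameters(recurse=False):
            pairs.append((module, param_name))
    return pairs

# Example: a nested stand-in model (not the exact network from the post)
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=2, stride=2),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=1)),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(8, 10),
)
pairs = collect_prunable(model)
for module, param_name in pairs:
    print(type(module).__name__, param_name)
```

The resulting list should be usable as the parameters argument of prune.global_unstructured. To prune only conv/linear weights, one option is to keep pairs where isinstance(module, (nn.Conv2d, nn.Linear)) and param_name == "weight".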

Here is exactly how I call prune, and the output:

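import torch.nn as nn
import torch.nn.utils.prune as prune

def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10):
    # any activation/normalization layers are not shown here
    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),  # patch embedding
        *[nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),  # depthwise conv
            nn.Conv2d(dim, dim, kernel_size=1),  # pointwise conv
        ) for i in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes),
    )

model = ConvMixer(args.hdim, args.depth)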

parameters_to_prune = []

for name, param in model.named_parameters():
    parameters_to_prune.append((model, name))

prune.global_unstructured(
    parameters=parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.05
)

Output:
Traceback (most recent call last):
  File "python3.9/site-packages/torch/nn/utils/", line 1088, in global_unstructured
  File "python3.9/site-packages/torch/nn/utils/", line 1089, in <listcomp>
    getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
  File "python3.9/site-packages/torch/nn/modules/", line 1177, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Sequential' object has no attribute '0.weight'
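One way around this AttributeError, sketched under the assumption of PyTorch >= 1.9 (for Module.get_submodule and padding="same"): keep the named_parameters() loop, but split each dotted name with str.rpartition and look up the submodule that owns it, so each tuple becomes (model[0], "weight") instead of (model, "0.weight"). The model is a small stand-in for the ConvMixer, not the exact network from the post.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Stand-in for the ConvMixer above: nested nn.Sequential blocks
dim, depth = 8, 2
model = nn.Sequential(
    nn.Conv2d(3, dim, kernel_size=2, stride=2),
    *[nn.Sequential(
        nn.Conv2d(dim, dim, 5, groups=dim, padding="same"),
        nn.Conv2d(dim, dim, kernel_size=1),
    ) for _ in range(depth)],
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(dim, 10),
)

# Same loop as in the post, but resolve the owner of each parameter:
# "1.0.weight" -> (model.get_submodule("1.0"), "weight")
parameters_to_prune = []
for name, _ in model.named_parameters():
    module_path, _, param_name = name.rpartition(".")
    owner = model.get_submodule(module_path)  # "" returns model itself
    parameters_to_prune.append((owner, param_name))

prune.global_unstructured(
    parameters=parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.05,
)

# Pruning adds a <name>_mask buffer per tensor; count zeros across all masks
total = sum(getattr(m, n + "_mask").numel() for m, n in parameters_to_prune)
zeros = sum(int((getattr(m, n + "_mask") == 0).sum()) for m, n in parameters_to_prune)
print(f"global sparsity: {zeros / total:.3f}")
```

Because amount is applied to all listed tensors jointly, the total number of zeros should be round(0.05 * total), while per-layer sparsity can vary. After pruning, named_parameters() reports names such as 0.weight_orig, so the list is best built before calling prune.global_unstructured.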