Prune.global_unstructured usability

I am finding prune.global_unstructured unnecessarily difficult to use. I have a torch.nn.Module that is a Sequential model made of layers and nested Sequential objects. I'd like to experiment with global pruning, so the natural starting point is to prune all parameters. To do so, I need to pass a list of (module, name) tuples, where name is the name of a parameter owned by that module.
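For a flat model the expected format is easy to write by hand. The sketch below uses a small toy Sequential (the layers and amount=0.2 are arbitrary illustration choices, not from the original model) and lists each (module, name) pair explicitly:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Toy flat model, used only to illustrate the expected input format.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Each entry is (module that directly owns the parameter, attribute name on it).
parameters_to_prune = [
    (model[0], "weight"),
    (model[0], "bias"),
    (model[2], "weight"),
    (model[2], "bias"),
]

prune.global_unstructured(
    parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2
)

# Pruning reparametrizes each entry: the module now holds a `<name>_mask` buffer.
for module, name in parameters_to_prune:
    print(type(module).__name__, name, hasattr(module, name + "_mask"))
```

The module in each tuple is the layer that owns the parameter (model[0]), not the top-level container.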

So I do:

parameters_to_prune = []

for name, val in model.named_parameters():
    parameters_to_prune.append((model, name))

Yet the implementation of prune.global_unstructured does not use Module.get_parameter, which would work with the keys from named_parameters. Instead it calls getattr(module, name), which fails because the actual parameter lives under a chain of submodules, e.g. model[0].weight for the '0.weight' parameter.
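A minimal illustration of the mismatch, assuming a PyTorch version recent enough to provide Module.get_parameter (the toy model is hypothetical): getattr resolves only a single attribute name, while get_parameter follows the dotted path that named_parameters produces.

```python
import torch.nn as nn

# Hypothetical nested model: the Linear layer sits inside an inner Sequential.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=2), nn.Sequential(nn.Linear(8, 4)))

names = [name for name, _ in model.named_parameters()]
print(names)  # ['0.weight', '0.bias', '1.0.weight', '1.0.bias']

# getattr looks up one attribute, so a dotted path is not found.
try:
    getattr(model, "1.0.weight")
except AttributeError as err:
    print("getattr failed:", err)

# Module.get_parameter walks the dotted path through the submodules.
param = model.get_parameter("1.0.weight")
print(param is model[1][0].weight)  # True
```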

Is there a reason for this design? Is there a better way than writing a fancy recursive function to find all leaf submodules of the model?
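One option that avoids hand-written recursion, sketched under the assumption that every parameter in the model should be pruned: Module.modules() already walks the whole module tree, and named_parameters(recurse=False) returns only the parameters a module owns directly. The helper name all_parameters_to_prune and the toy model are hypothetical.

```python
import torch.nn as nn
import torch.nn.utils.prune as prune


def all_parameters_to_prune(model: nn.Module):
    """Return (module, name) pairs for every parameter in `model` (hypothetical helper).

    modules() visits every submodule recursively; recurse=False keeps each
    parameter attached to the module that directly owns it, so nothing is
    listed twice.
    """
    return [
        (module, name)
        for module in model.modules()
        for name, _ in module.named_parameters(recurse=False)
    ]


# Toy nested model standing in for a real network.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=2, stride=2),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=1), nn.BatchNorm2d(8)),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(8, 10),
)

parameters_to_prune = all_parameters_to_prune(model)
prune.global_unstructured(
    parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.05
)
```

Because recurse=False attributes each parameter to exactly one module, the list has one entry per parameter tensor, which matches what global_unstructured expects.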

Here is exactly how I call prune, and the output:


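import torch.nn as nn
import torch.nn.utils.prune as prune


class Residual(nn.Module):
    # Residual wrapper as in the ConvMixer reference code: returns fn(x) + x
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


def ConvMixer(dim, depth, kernel_size=5, patch_size=2, n_classes=10):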
    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(dim)
            )),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim)
        ) for i in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )
model = ConvMixer(args.hdim, args.depth)

parameters_to_prune = []

for name, param in model.named_parameters():
    parameters_to_prune.append((model, name))

prune.global_unstructured(
    parameters=parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.05
)

""" Output
Traceback (most recent call last):
  File "python3.9/site-packages/torch/nn/utils/prune.py", line 1088, in global_unstructured
    [
  File "python3.9/site-packages/torch/nn/utils/prune.py", line 1089, in <listcomp>
    getattr(module, name + "_mask", torch.ones_like(getattr(module, name)))
  File "python3.9/site-packages/torch/nn/modules/module.py", line 1177, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Sequential' object has no attribute '0.weight'
"""