Pruning torch.nn.modules.activation.MultiheadAttention


I am trying to prune torch.nn.modules.activation.MultiheadAttention. However, its “in_proj_weight”, “q_proj_weight”, “k_proj_weight” and “v_proj_weight” matrices are not of type Module but of type Parameter.

As a result, they don’t show up when I use:
modules = model.named_modules()
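A quick check makes the difference visible. It assumes default constructor arguments, so kdim = vdim = embed_dim and the packed in_proj_weight is used; the sizes are arbitrary. The projection weight appears in named_parameters() but not in named_modules(), where the only child is the output projection out_proj.

```python
import torch.nn as nn

# Default kdim/vdim == embed_dim, so Q/K/V share one packed in_proj_weight
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# Submodules: the layer itself ('') and its output projection
module_names = [name for name, _ in mha.named_modules()]

# Parameters: includes the packed projection weight; unused ones
# (q_proj_weight etc., registered as None here) are skipped
param_names = [name for name, _ in mha.named_parameters()]

print(module_names)
print(param_names)
```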

And as far as I know, one can only prune modules.

Do I have to write my own multi-head attention implementation in which the query, key and value projections are modules (e.g. nn.Linear)?
Or is there a workaround?


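It turns out I can pass ‘in_proj_weight’ (or ‘q_proj_weight’ etc.) as the parameter name instead of ‘weight’, with the MultiheadAttention layer itself as the module.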
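A minimal sketch of that approach, assuming a recent PyTorch and the default kdim = vdim = embed_dim. The layer sizes and the 30% amount are arbitrary illustration values. The torch.nn.utils.prune functions take a module plus the name of one of its parameters, so the attention layer can serve as the module directly. If kdim or vdim differ from embed_dim, the same call would name q_proj_weight, k_proj_weight or v_proj_weight instead.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Default kdim/vdim == embed_dim, so Q/K/V share the packed in_proj_weight
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)

# Pass the attention layer as the module and name the Parameter explicitly;
# this zeroes the 30% of entries with the smallest absolute value
prune.l1_unstructured(mha, name="in_proj_weight", amount=0.3)

# Pruning reparametrizes the weight: the original values move to
# in_proj_weight_orig, the 0/1 mask becomes the in_proj_weight_mask buffer,
# and a forward pre-hook recomputes in_proj_weight = orig * mask
num_zeros = int((mha.in_proj_weight == 0).sum())
sparsity = num_zeros / mha.in_proj_weight.numel()
print(f"in_proj_weight sparsity: {sparsity:.2f}")

# The pruned layer still runs; input shape is (seq_len, batch, embed_dim)
x = torch.randn(5, 1, 8)
out, _ = mha(x, x, x)

# Optionally make the pruning permanent (drops the mask, _orig and the hook)
prune.remove(mha, "in_proj_weight")
```

After prune.remove, in_proj_weight is an ordinary Parameter again with the zeros baked in, so the layer can be saved and used without the pruning machinery.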