torch.nn.utils.prune errors

Hi, I'm working through your pruning tutorial, specifically the part on custom pruning methods that subclass the abstract class BasePruningMethod.

However, with torch 1.6.0 I run into problems with module._parameters.

As I understand the source code of BasePruningMethod, its apply classmethod is what calls compute_mask to build the mask.
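For context, here is a minimal sketch of that subclass pattern, modeled on the tutorial's "prune every other entry" example. The class name EveryOtherPruningMethod and the helper every_other_unstructured are hypothetical; the calls on torch.nn.utils.prune are the public ones.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class EveryOtherPruningMethod(prune.BasePruningMethod):
    """Hypothetical example: zero out every other entry of the tensor."""

    # Required class attribute: "unstructured", "structured" or "global"
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Start from the existing mask so any earlier pruning is kept
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask


def every_other_unstructured(module, name):
    # apply() builds the method, calls compute_mask, and reparametrizes
    # module.<name> into <name>_orig (a Parameter) and <name>_mask (a buffer)
    EveryOtherPruningMethod.apply(module, name)
    return module


layer = nn.Linear(3, 2)  # weight has shape (2, 3)
every_other_unstructured(layer, "weight")
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(n for n, _ in layer.named_buffers()))     # ['weight_mask']
print(layer.weight_mask)
```

After apply, layer.weight is no longer a registered parameter; it is a plain attribute recomputed as weight_orig * weight_mask by a forward pre-hook.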

However, when I call Prunemethod.apply(module, name), I get this error:

```
cannot assign 'torch.FloatTensor' object to parameter 'weight_orig' (torch.nn.Parameter or None required)
```
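One situation that produces exactly this message (an assumption about the cause, not a confirmed diagnosis): apply calls module.register_parameter(name + "_orig", orig), and register_parameter rejects anything that is not an nn.Parameter. That happens when the attribute being pruned is a plain tensor instead of a registered parameter. PlainWeightModule below is hypothetical; prune.l1_unstructured is used because it goes through the same BasePruningMethod.apply path as a custom method.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class PlainWeightModule(nn.Module):
    """Hypothetical module that stores its weight as a plain tensor."""

    def __init__(self):
        super().__init__()
        self.weight = torch.randn(2, 3)  # plain attribute, NOT an nn.Parameter


m = PlainWeightModule()
print("weight" in m._parameters)  # False: the tensor was never registered

message = None
try:
    prune.l1_unstructured(m, "weight", amount=0.5)
except TypeError as err:
    # register_parameter() only accepts nn.Parameter (or None)
    message = str(err)
    print(message)
```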

So I tried overriding apply in my subclass and changing it like this:

```python
# Originally around line 164 of prune.py on GitHub:
# module.register_parameter(name + "_orig", orig)
# Changed to wrap the tensor in a Parameter:
module.register_parameter(name + "_orig", torch.nn.Parameter(orig))
```

But after that change I get another error at del module._parameters[name], because there is no 'weight' in module._parameters.
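Assuming the pruned tensor is a plain attribute rather than a registered parameter, this second error is consistent: wrapping orig in nn.Parameter gets past register_parameter, but there is then no 'weight' key to delete. The sketch below (both module classes are hypothetical) contrasts that state with a module that registers its weight as an nn.Parameter, where the unmodified apply works.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class PlainWeightModule(nn.Module):
    """Hypothetical module whose weight is a plain tensor attribute."""

    def __init__(self):
        super().__init__()
        self.weight = torch.randn(2, 3)  # not registered in _parameters


class ParamWeightModule(nn.Module):
    """Same module, but the weight is registered as an nn.Parameter."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2, 3))


plain = PlainWeightModule()
key_error = None
try:
    # The step apply() performs after register_parameter()
    del plain._parameters["weight"]
except KeyError as err:
    key_error = err
    print("KeyError:", err)  # KeyError: 'weight'

fixed = ParamWeightModule()
prune.l1_unstructured(fixed, "weight", amount=0.5)  # no patched apply() needed
print(sorted(n for n, _ in fixed.named_parameters()))  # ['weight_orig']
print(int(fixed.weight_mask.sum()))                    # 3 of 6 entries kept
```

If this matches your module, registering the weight as an nn.Parameter (or pruning a standard layer such as nn.Linear) may avoid patching apply at all.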

Is there a better option than downgrading torch?

When I downgrade torch to 1.4.0, I hit a different problem: compute_mask doesn't accept a 'dim' argument.
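If the 'dim' problem comes from adding an extra parameter to compute_mask: in BasePruningMethod.apply, extra positional and keyword arguments are forwarded to the class constructor, while compute_mask is called with only the tensor and default_mask. The built-in LnStructured stores dim on the instance for this reason. A sketch of that pattern, written against the 1.6-style API and not checked on 1.4.0; FirstSliceStructured is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class FirstSliceStructured(prune.BasePruningMethod):
    """Hypothetical structured method: zero the first slice along `dim`."""

    PRUNING_TYPE = "structured"

    def __init__(self, dim=-1):
        # Extra arguments given to apply() end up here, so settings such
        # as `dim` live on the instance rather than on compute_mask
        self.dim = dim

    def compute_mask(self, t, default_mask):
        # compute_mask only receives (t, default_mask); read dim from self
        mask = default_mask.clone()
        index = [slice(None)] * t.dim()
        index[self.dim] = 0
        mask[tuple(index)] = 0
        return mask


layer = nn.Linear(3, 2)  # weight has shape (2, 3)
FirstSliceStructured.apply(layer, "weight", dim=0)  # dim is passed to __init__
print(layer.weight_mask)  # first row (index 0 along dim 0) is zeroed
```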

Thanks in advance