Using nn.utils.prune causes torch.Tensor.__deepcopy__ to fail

When pruning a module with the utilities in torch.nn.utils.prune (following the official PyTorch Pruning tutorial), the “pruned” module becomes non-deep-copiable.

Code to reproduce:

import copy
import torch
import torch.nn.utils.prune as prune

foo = torch.nn.Conv2d(2, 4, 1)
foo2 = copy.deepcopy(foo) # copy is successful before pruning

foo = prune.l1_unstructured(foo, name='weight', amount=0.2)
foo3 = copy.deepcopy(foo) # copy fails after pruning

The last line throws the following: RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. The error comes from this check.
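For reference, the guard in torch.Tensor.__deepcopy__ looks roughly like this (paraphrased from the PyTorch source, so the exact surrounding code may differ across versions):

# Inside torch.Tensor.__deepcopy__ (paraphrased)
def __deepcopy__(self, memo):
    if not self.is_leaf:
        raise RuntimeError("Only Tensors created explicitly by the user "
                           "(graph leaves) support the deepcopy protocol "
                           "at the moment")
    # ... the actual copy logic follows ...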

Quick explanation of what the pruning call does (steps 1 and 2 can be verified with the check after this list):
1) It re-registers the pruned parameter under the “_orig” suffix here
2) It registers the pruning mask as a buffer under the “_mask” suffix here
3) It reuses the original attribute name with setattr(module, self._tensor_name, self.apply_mask(module)) here
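Continuing the snippet above, listing the module's parameters and buffers confirms steps 1) and 2):

print(sorted(n for n, _ in foo.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(n for n, _ in foo.named_buffers()))     # ['weight_mask']
print('weight' in dict(foo.named_parameters()))      # False: plain attribute now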

Only the last step (the setattr call) seems to put a tensor that is not a leaf on the module:

print(foo.weight_orig.is_leaf) # True
print(foo.weight_mask.is_leaf) # True
print(foo.weight.is_leaf) # False

My 10-second solution was to monkey-patch torch.Tensor.__deepcopy__ so that I could force a deepcopy by setting self.is_leaf = True and retrying the deepcopy if it failed. My assumption was that the is_leaf requirement has to do with properly copying gradients and minding the gradient tape. Since I intend to use the result of the deepcopy only for validation, I don't care about anything but forward-pass correctness. (Edit: setting is_leaf is not allowed: AttributeError: attribute 'is_leaf' of 'torch._C._TensorBase' objects is not writable)
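A monkey-patch variant that does work is to fall back to a detached clone for non-leaf tensors. This is only a sketch (the name _lenient_deepcopy is mine), and it is only safe when you don't need autograd history on the copy, e.g. for validation-only forward passes:

import copy
import torch

_original_deepcopy = torch.Tensor.__deepcopy__

def _lenient_deepcopy(self, memo):
    # Leaf tensors: defer to the stock implementation.
    if self.is_leaf:
        return _original_deepcopy(self, memo)
    # Non-leaf tensors: copy values only. The clone is a fresh leaf with no
    # autograd history, so gradients will NOT flow through the copy.
    result = self.detach().clone()
    memo[id(self)] = result
    return result

torch.Tensor.__deepcopy__ = _lenient_deepcopy

Since the pruning forward_pre_hook recomputes the masked weight on every call, the copied module should still produce correct forward results.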

I am looking to understand the original motivation for the is_leaf check in torch.Tensor.__deepcopy__ so that I can see whether my assumption was correct.

Finally, a safer but tedious alternative is to remove the pruning reparameterization just before the deepcopy, so that all the modifications the pruner made to the module are undone and the pruning mask is fused into the original weights. This works, but it means that pruning is not resumable after the reparameterization is removed.
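Concretely (continuing the first snippet), this alternative uses prune.remove:

# Fuse the mask into 'weight', which becomes a plain (leaf) Parameter again.
prune.remove(foo, 'weight')
print(foo.weight.is_leaf)           # True
foo3 = copy.deepcopy(foo)           # copy is successful now
# But the reparameterization is gone, so pruning cannot be resumed:
print(hasattr(foo, 'weight_orig'))  # False
print(hasattr(foo, 'weight_mask'))  # False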

It would be awesome if someone (especially @ptrblck) could propose a workaround or point out a mistake of mine to resolve this matter. Thanks in advance, folks!

P.S. Not sure whether this is related yet, but an earlier topic by @pinocchio discussed issues with deepcopy when using setattr to set tensors as attributes of a module. setattr is used several times in torch.nn.utils.prune (e.g. here). It might also be something else entirely.

My current solution is to prepare the pruned module for a successful deepcopy by deleting the non-leaf tensor from the module's attributes (it is neither a parameter nor a buffer of the module):

import copy
import torch
import torch.nn.utils.prune as prune

foo = torch.nn.Conv2d(2, 4, 1)
foo2 = copy.deepcopy(foo) # copy is successful before pruning

pruner = prune.L1Unstructured.apply(foo, name='weight', amount=0.2)  # returns the pruning-method instance
try:
    foo3 = copy.deepcopy(foo)  # copy fails after pruning
except RuntimeError as e:
    print(e)

print(pruner._tensor_name) # 'weight'
delattr(foo, pruner._tensor_name)
foo4 = copy.deepcopy(foo) # copy is successful

print(hasattr(foo4, pruner._tensor_name)) # False
# Run forward() once to recover the deleted attribute 
_ = foo4(torch.empty(1, 2, 4, 4))
print(hasattr(foo4, pruner._tensor_name)) # True

Note that the deleted attribute is restored during the first forward call on the module. This is because it is re-registered on every call by the forward_pre_hook that the BasePruningMethod class implements here.
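Packaged as a helper, the whole dance could look like this sketch (deepcopy_pruned is a hypothetical name; it restores the attribute eagerly by re-applying the mask the same way the hook does):

import copy
import torch

def deepcopy_pruned(module, name='weight'):
    """Deepcopy a pruned module by temporarily removing its non-leaf tensor."""
    had_attr = hasattr(module, name)
    if had_attr:
        delattr(module, name)        # drop the non-leaf attribute
    clone = copy.deepcopy(module)    # every remaining tensor is a leaf
    if had_attr:
        # Restore the attribute on both modules without waiting for a forward
        # pass, mirroring what apply_mask does in the pruning hook.
        for m in (module, clone):
            orig = getattr(m, name + '_orig')
            mask = getattr(m, name + '_mask')
            setattr(m, name, mask * orig)
    return clone

The restored attribute is again a non-leaf tensor, exactly as after a normal forward pass, so a later deepcopy_pruned call still works because it deletes the attribute before copying.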

cc: @Michela, any thoughts would be much appreciated for a longer-term solution :slight_smile: