When pruning a module using the utilities in torch.nn.utils.prune
(following the official PyTorch Pruning tutorial), the "pruned" module becomes non-deep-copiable.
Code to reproduce:
import copy
import torch
import torch.nn.utils.prune as prune
foo = torch.nn.Conv2d(2, 4, 1)
foo2 = copy.deepcopy(foo) # copy is successful before pruning
foo = prune.l1_unstructured(foo, name='weight', amount=0.2)
foo3 = copy.deepcopy(foo) # copy fails after pruning
The last line throws the following: RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment, because of this check.
Quick explanation of what the pruning call does:
1-) It re-registers the pruned parameter under the "_orig" suffix here
2-) It registers the pruning mask as a buffer under the "_mask" suffix here
3-) It reuses the original weight name with setattr(module, self._tensor_name, self.apply_mask(module)) here
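A simplified sketch of those three steps (using an all-ones stand-in mask instead of the real L1-magnitude mask, and poking module._parameters directly like the pruner does internally):

```python
import torch

conv = torch.nn.Conv2d(2, 4, 1)
orig = conv.weight

# 1-) re-register the original parameter under the "_orig" suffix
del conv._parameters['weight']
conv.register_parameter('weight_orig', orig)

# 2-) register the pruning mask as a buffer under the "_mask" suffix
#     (all ones here; the real pruner computes it from L1 magnitudes)
conv.register_buffer('weight_mask', torch.ones_like(orig))

# 3-) reuse the original name: the mask*orig product carries a grad_fn,
#     so the resulting tensor is NOT a leaf
setattr(conv, 'weight', conv.weight_mask * conv.weight_orig)

print(conv.weight.is_leaf)  # prints False
```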
Only the last step seems to put a non-leaf tensor on the module:
print(foo.weight_orig.is_leaf) # True
print(foo.weight_mask.is_leaf) # True
print(foo.weight.is_leaf) # False
My 10-second solution was to monkey-patch torch.Tensor.__deepcopy__
so that I could force a deepcopy by setting self.is_leaf = True
and retrying the deepcopy if it failed. My assumption was that the is_leaf check
has to do with properly copying gradients and minding the gradient tape. Since I intend to use the result of the deepcopy only for validation, I didn't care about anything but forward-pass correctness. (Edit: setting is_leaf
is not allowed: AttributeError: attribute 'is_leaf' of 'torch._C._TensorBase' objects is not writable.)
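A detach-based variant of that idea does seem to work for my forward-only use case (a sketch, assuming the pruning forward pre-hook recomputes weight from weight_orig and weight_mask on the next forward, which is my reading of apply_mask):

```python
import copy
import torch
import torch.nn.utils.prune as prune

foo = torch.nn.Conv2d(2, 4, 1)
foo = prune.l1_unstructured(foo, name='weight', amount=0.2)

# after pruning, 'weight' is a plain attribute (not a Parameter), so we
# can overwrite it with a detached (leaf) copy; the pruning forward
# pre-hook recomputes it from weight_orig * weight_mask on each forward
foo.weight = foo.weight.detach()

foo3 = copy.deepcopy(foo)  # succeeds now

x = torch.randn(1, 2, 5, 5)
print(torch.allclose(foo(x), foo3(x)))  # prints True
```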
I am looking to understand the original motivation for the is_leaf check in torch.Tensor.__deepcopy__ so that I can see whether my assumption is correct.
Finally, a safer but more tedious alternative is to remove the pruning reparameterization just before the deepcopy, so that all the modifications the pruner made to the module are undone and the pruning mask is fused into the original weights. This works, but it means that pruning is not resumable after the reparameterization is removed.
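For reference, that alternative looks like this with prune.remove (the module deep-copies again afterwards, at the cost of losing weight_orig, weight_mask, and the pruning hook):

```python
import copy
import torch
import torch.nn.utils.prune as prune

foo = torch.nn.Conv2d(2, 4, 1)
foo = prune.l1_unstructured(foo, name='weight', amount=0.2)

# fuse the mask into the weight and strip weight_orig / weight_mask and
# the forward pre-hook; pruning is no longer resumable after this
prune.remove(foo, 'weight')

foo3 = copy.deepcopy(foo)   # succeeds: weight is a leaf Parameter again
print(foo.weight.is_leaf)   # prints True
```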
It would be awesome if someone (especially @ptrblck) could propose a workaround or point out a mistake of mine to resolve this matter. Thanks in advance folks!
P.S. Not sure whether this is related yet, but an earlier topic by @pinocchio discussed issues with deepcopy and using setattr to set tensors as attributes of a module. setattr is used several times in torch.nn.utils.prune (e.g. here). It might just have to do with something else entirely.