I am trying to train a sparsified network. I can construct that network using something like:
```python
masks = get_masks(my_network, quantile)

def prune_hook(module, input):
    for param, mask in zip(module.parameters(), masks):
        param.data = param.data * mask

my_network.register_forward_pre_hook(prune_hook)  # ~10% of weights will be nonzero during forward pass
```
This snippet guarantees that anytime I make a forward pass, the appropriate weights are "zeroed out". I'm wondering whether PyTorch still performs the full gradient computation during each backward pass, or whether there's any way to inform PyTorch that it need only compute gradients for a subset of the weights. Part of the rationale for using a sparsified network is increased computational efficiency during training. I suppose I am wondering whether that computational efficiency must be hand-rolled, or whether it can be squeezed out of PyTorch directly.
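For concreteness, here is a minimal sketch of the kind of gradient masking I could hand-roll with tensor hooks (`nn.Linear` stands in for my real network, and the quantile threshold mimics my `get_masks` helper). As far as I can tell, autograd still computes the full dense gradient here; the hook merely multiplies it by the mask afterwards, so the pruned weights stay at zero but no compute is actually saved:

```python
import torch
import torch.nn as nn

# Sketch: keep only the top ~10% of entries of each parameter by magnitude,
# and zero the gradients of the rest via per-tensor backward hooks.
model = nn.Linear(4, 3)

masks = {
    name: (p.detach().abs() >= p.detach().abs().flatten().quantile(0.9)).float()
    for name, p in model.named_parameters()
}

for name, p in model.named_parameters():
    # default-argument binding so each hook captures its own mask
    p.register_hook(lambda grad, m=masks[name]: grad * m)

x = torch.randn(2, 4)
model(x).sum().backward()
# p.grad is now zero wherever the corresponding mask entry is zero
```

This keeps the sparsity pattern fixed during training, but the backward pass is still dense matrix math under the hood.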