Efficient gradient computation on a sparsified/pruned network

I am trying to train a sparsified network. I can construct that network using something like:

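import torch
masks = get_masks(my_network, quantile)  # my own helper: one 0/1 mask per parameter
def prune_hook(module, input):
    # Zero out the pruned weights in place before every forward pass
    with torch.no_grad():
        for param, mask in zip(module.parameters(), masks):
            param.mul_(mask)
my_network.register_forward_pre_hook(prune_hook)
# ~10% of weights will be nonzero during the forward pass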

With the hook registered, the appropriate weights are zeroed out on every forward pass. What I'm wondering is whether PyTorch still performs the full (dense) gradient computation during each backward pass, or whether there is a way to tell PyTorch that it only needs to compute gradients for the unmasked subset of the weights. Part of the rationale for using a sparsified network is greater computational efficiency during training, so I'd like to know whether that efficiency has to be hand-rolled from scratch or whether it can be squeezed out of PyTorch.
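One way to check what autograd does is a small experiment on a single nn.Linear layer. This is a minimal sketch, assuming a magnitude-based mask as a hypothetical stand-in for get_masks (whose implementation is not shown here). The first half shows that the backward pass still produces a full dense weight gradient, with nonzero values even at the pruned positions, because for a linear layer dL/dW = (dL/dy)^T x does not depend on the values stored in W. The second half masks the gradient with Tensor.register_hook so that, for example, plain SGD leaves the pruned weights at zero; this only zeroes the result and does not skip any of the computation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(8, 4)

# Hypothetical mask (stand-in for get_masks): keep roughly the top 10% of
# weights by magnitude, prune the rest.
threshold = torch.quantile(layer.weight.detach().abs(), 0.9)
mask = (layer.weight.detach().abs() >= threshold).float()

# Zero the pruned weights in place, as the forward pre-hook does.
with torch.no_grad():
    layer.weight.mul_(mask)

x = torch.randn(16, 8)
layer(x).sum().backward()

# Autograd computes the full dense gradient: entries where the weight is
# zero still receive (generally nonzero) gradients.
n_pruned = int((mask == 0).sum())
dense_nonzero = int((layer.weight.grad[mask == 0] != 0).sum())
print("gradient shape:", tuple(layer.weight.grad.shape))
print("pruned weights:", n_pruned, "| nonzero grads at pruned positions:", dense_nonzero)

# Masking the gradient with a tensor hook: the returned tensor replaces the
# gradient before it is accumulated into .grad. The backward matmul is still
# dense; only the result is zeroed.
layer.weight.grad = None
handle = layer.weight.register_hook(lambda g: g * mask)
layer(x).sum().backward()
handle.remove()

masked_nonzero = int((layer.weight.grad[mask == 0] != 0).sum())
print("nonzero grads at pruned positions after masking:", masked_nonzero)
```

PyTorch's built-in torch.nn.utils.prune.custom_from_mask takes a similar approach (it multiplies the original weight by a stored mask), so gradients at pruned positions come out as zero, but the underlying operations are still dense. Actually skipping the work for pruned weights would require sparse kernels, which this sketch does not attempt.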