Zeroing gradients makes training very slow

I am implementing the “Deconstructing Lottery Tickets” paper, and I am training AlexNet several times.
In the first epoch, the network has exactly the same weights as AlexNet.
In the second epoch, I select 20% of the weights (4M indexes in index_dict4d) and keep them at zero while training.
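For illustration, index_dict4d could be built along these lines for conv1 (a minimal sketch only; the magnitude-based selection criterion and the torchvision AlexNet layout are assumptions here, and the actual criterion may differ):

```python
import torch
from torchvision.models import alexnet

model = alexnet()

# Sketch: record the (a, b, c, d) indexes of the 20% smallest-magnitude
# conv1 weights, so a hook can later zero their gradients.
w = model.features[0].weight.detach()             # conv1 weight, shape (64, 3, 11, 11)
k = int(0.2 * w.numel())                          # 20% of the entries
threshold = w.abs().flatten().kthvalue(k).values  # k-th smallest magnitude
index_dict4d = {"conv1": (w.abs() <= threshold).nonzero().tolist()}
```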
To do this, I defined hooks for all layers that set grad_clone[some indexes] = 0:
```python
def my_hook4d_conv1(grad):
    # Clone the gradient and zero the entries of the pruned weights.
    grad_clone = grad.clone()
    for i in range(len(index_dict4d["conv1"])):
        a, b, c, d = index_dict4d["conv1"][i]
        grad_clone[a, b, c, d] = 0
    return grad_clone
```
and then registered a hook for each layer before training started for the second epoch.
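For reference, the registration step looks roughly like this (a sketch assuming torchvision's AlexNet, where conv1 is model.features[0]; the exact module access depends on my setup):

```python
# Register the gradient hook on conv1's weight tensor; autograd calls it
# on every backward pass with that parameter's gradient.
handle = model.features[0].weight.register_hook(my_hook4d_conv1)
# handle.remove()  # can be called later to detach the hook
```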

  1. This process makes training very slow (4M operations in the above for loop). Is there a way to speed it up? One idea I am considering is sketched below.
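One direction that seems plausible is to precompute a 0/1 mask once and multiply it into the gradient, instead of looping over individual indexes in Python. A minimal sketch, assuming index_dict4d["conv1"] holds the (a, b, c, d) indexes as above and reusing model from earlier:

```python
import torch

# Build the mask once, before training: 0 where a weight is pruned, 1 elsewhere.
idx = torch.tensor(index_dict4d["conv1"])         # shape (num_pruned, 4)
mask = torch.ones_like(model.features[0].weight)
mask[idx[:, 0], idx[:, 1], idx[:, 2], idx[:, 3]] = 0

def my_hook4d_conv1(grad):
    # One vectorized multiply per backward pass instead of ~4M Python assignments.
    return grad * mask
```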