How to drop gradient information of `nn.Module`?

I’m experimenting with some filter pruning tricks after training a model until convergence.

First, the model is trained with L1 regularization on the BN weights.
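A minimal sketch of such a penalty, assuming the usual "network slimming" setup in which the absolute values of every `BatchNorm2d` scale factor (`weight`) are summed and added to the task loss. The helper `bn_l1_penalty`, the coefficient `l1_lambda`, and the toy model are hypothetical; the post does not show the actual training code.

```python
import torch
import torch.nn as nn


def bn_l1_penalty(model: nn.Module) -> torch.Tensor:
    """Sum of |gamma| over all BatchNorm2d scale factors (hypothetical helper).

    Assumes the model contains at least one BatchNorm2d layer.
    """
    terms = [m.weight.abs().sum() for m in model.modules()
             if isinstance(m, nn.BatchNorm2d)]
    return torch.stack(terms).sum()


model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
criterion = nn.MSELoss()
l1_lambda = 1e-4  # hypothetical sparsity coefficient

x = torch.randn(2, 3, 16, 16)
target = torch.randn(2, 8, 16, 16)

# Task loss plus the L1 term that pushes unimportant BN scale factors toward zero
loss = criterion(model(x), target) + l1_lambda * bn_l1_penalty(model)
loss.backward()
```

The BN scale factors that end up near zero after training are the natural candidates for pruning in the next step.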

Then I manually prune several kernels and BN parameters according to the BN weights, reassigning the `.data` of each affected tensor (conv weights and BN parameters) to its slice indexed by `pruned_index`.
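The post does not say how `pruned_index` is computed. One common choice, shown here only as an assumption, is to rank a BN layer's channels by the magnitude of their scale factors and keep the largest ones. `channels_to_keep` and `keep_ratio` are hypothetical names.

```python
import torch
import torch.nn as nn


def channels_to_keep(bn: nn.BatchNorm2d, keep_ratio: float) -> torch.Tensor:
    """Indices of the channels with the largest |gamma|, in ascending order (hypothetical helper)."""
    num_keep = max(1, int(bn.num_features * keep_ratio))
    scores = bn.weight.detach().abs()
    top = torch.topk(scores, num_keep).indices
    return torch.sort(top).values  # preserve the original channel order


bn = nn.BatchNorm2d(6)
with torch.no_grad():
    bn.weight.copy_(torch.tensor([0.9, 0.01, 0.5, 0.02, 0.7, 0.03]))

keep_idx = channels_to_keep(bn, keep_ratio=0.5)
print(keep_idx.tolist())  # [0, 2, 4]
```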

When I fine-tune the pruned model, the following error occurs:

```
RuntimeError: Function CudnnBatchNormBackward returned an invalid gradient at index 1 - got [3] but expected shape compatible with [512]
```

Since the train -> prune -> fine-tune procedure runs in a single script, I suspect that some gradient information (the shapes of the gradients and parameters) is still kept around from training the model, and that this is where the error comes from.
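One way to test that hypothesis is to look for parameters whose stored `.grad` no longer matches the parameter's own shape after pruning. `stale_grads` below is a hypothetical diagnostic, not a PyTorch API. It only inspects `.grad`; any shape information autograd may hold internally is not visible this way.

```python
import torch
import torch.nn as nn


def stale_grads(model: nn.Module) -> list:
    """Names of parameters whose .grad shape differs from the parameter shape (hypothetical helper)."""
    return [name for name, p in model.named_parameters()
            if p.grad is not None and p.grad.shape != p.shape]


bn = nn.BatchNorm2d(4)
bn(torch.randn(2, 4, 3, 3)).sum().backward()  # populates .grad with shape [4]

# Mimic the .data-based pruning from the question: shrink only the weight
bn.weight.data = bn.weight.data[:2]
print(stale_grads(bn))  # ['weight']
```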

Currently I work around this by saving the model parameters to `checkpoint.pth` and reloading them into a new model, so that the gradient information is cleared:

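```python
import torch
from torchvision.models import resnet18

model = resnet18()

# train model

# save checkpoint
torch.save(model.state_dict(), "checkpoint.pth")

# reload into a fresh model, so no gradient state from training survives
model = resnet18()
model.load_state_dict(torch.load("checkpoint.pth"))

# prune

# finetune
```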

Is there an easier way to remove all gradient information?
My current solution is not very elegant.
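For the narrower task of dropping the stored gradients themselves, `nn.Module.zero_grad(set_to_none=True)` (available since PyTorch 1.7) sets every parameter's `.grad` to `None`. Whether this alone avoids the error above is an assumption not tested in this thread. Separately, an optimizer created before pruning keeps per-parameter state (such as momentum buffers) with the old shapes, so the sketch below builds a fresh one after pruning.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
model(torch.randn(5, 4)).sum().backward()

# Drop every stored .grad tensor (they become None instead of zero-filled)
model.zero_grad(set_to_none=True)
print(all(p.grad is None for p in model.parameters()))  # True

# After pruning, build a fresh optimizer so no per-parameter state survives
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
```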

I haven’t tested it, but manipulating the `.data` attribute might have unexpected side effects and is not recommended.
Could you instead wrap the pruning in a `torch.no_grad()` block and reassign the pruned parameters?
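A minimal sketch of that suggestion, assuming a `Conv2d` followed by a `BatchNorm2d` whose output channels are pruned with a hypothetical index tensor `keep_idx`. Inside `torch.no_grad()`, each pruned weight is wrapped in a fresh `nn.Parameter` (and each running statistic reassigned as a plain buffer tensor) instead of overwriting `.data`, and the layers' channel-count attributes are updated to match. The helper name `prune_conv_bn` is hypothetical.

```python
import torch
import torch.nn as nn


def prune_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d, keep_idx: torch.Tensor) -> None:
    """Keep only the output channels listed in keep_idx (hypothetical helper)."""
    with torch.no_grad():
        # New Parameter objects: fresh leaves with no gradient history
        conv.weight = nn.Parameter(conv.weight[keep_idx].clone())
        if conv.bias is not None:
            conv.bias = nn.Parameter(conv.bias[keep_idx].clone())
        conv.out_channels = len(keep_idx)

        bn.weight = nn.Parameter(bn.weight[keep_idx].clone())
        bn.bias = nn.Parameter(bn.bias[keep_idx].clone())
        # Running statistics are buffers, so plain tensors are assigned
        bn.running_mean = bn.running_mean[keep_idx].clone()
        bn.running_var = bn.running_var[keep_idx].clone()
        bn.num_features = len(keep_idx)


conv = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.BatchNorm2d(8)
model = nn.Sequential(conv, bn, nn.ReLU())

# "Train" once so that gradients exist before pruning
model(torch.randn(2, 3, 8, 8)).sum().backward()

prune_conv_bn(conv, bn, torch.tensor([0, 2, 5]))

# Fine-tune: the optimizer is created after pruning so it sees the new parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
out = model(torch.randn(2, 3, 8, 8))
out.sum().backward()
optimizer.step()
```

Because the pruned tensors are new `Parameter` objects, an optimizer created before pruning would still point at the old ones. In a full ResNet, the following layer's input channels (dimension 1 of its weight) would also need the same indexing; that step is omitted here.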

Wrapping the pruning in a `torch.no_grad()` block solved my problem.