How to drop the gradient information of an `nn.Module`?

I’m experimenting with some filter-pruning tricks after training a model until convergence.

First, the model is trained with L1 regularization on the BN weights.
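One common way to express L1 regularization on the BN weights (the scaling factors γ) is to add λ·Σ|γ| over all BN layers to the training loss. A minimal sketch, assuming `BatchNorm2d` layers, a hypothetical strength `lam`, a hypothetical helper `bn_l1_penalty`, and a dummy loss standing in for the real one:

```python
import torch
import torch.nn as nn


def bn_l1_penalty(model: nn.Module) -> torch.Tensor:
    # Hypothetical helper: sum of |gamma| over all BatchNorm2d layers
    # (assumes the model contains at least one such layer).
    return torch.stack(
        [m.weight.abs().sum() for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    ).sum()


model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
x = torch.randn(2, 3, 8, 8)
lam = 1e-4  # hypothetical regularization strength

task_loss = model(x).mean()  # stand-in for the real training loss
loss = task_loss + lam * bn_l1_penalty(model)
loss.backward()  # gradients now include the L1 term on every BN weight
```

Since BN weights are initialized to 1, the penalty for this freshly built 8-channel layer is 8.0 before any training step.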

Then I manually prune several kernels and BN parameters according to the BN weights:

conv.weight.data = conv.weight.data[pruned_index]
bn.weight.data = bn.weight.data[pruned_index]
bn.bias.data = bn.bias.data[pruned_index]
bn.running_mean.data = bn.running_mean.data[pruned_index]
bn.running_var.data = bn.running_var.data[pruned_index]
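The snippet above does not show how `pruned_index` is obtained. One hypothetical criterion that prunes according to the BN weights is to keep the channels with the largest |γ|. The helper name `select_channels` and the `keep_ratio` parameter are assumptions for illustration:

```python
import torch
import torch.nn as nn


def select_channels(bn: nn.BatchNorm2d, keep_ratio: float) -> torch.Tensor:
    """Hypothetical helper: indices of the channels with the largest |gamma|."""
    gamma = bn.weight.detach().abs()
    k = max(1, int(round(keep_ratio * gamma.numel())))
    keep = torch.topk(gamma, k).indices
    # Sort so the surviving channels keep their original order.
    return torch.sort(keep).values


bn = nn.BatchNorm2d(6)
with torch.no_grad():
    bn.weight.copy_(torch.tensor([0.9, 0.01, 0.5, 0.02, 0.7, 0.03]))

pruned_index = select_channels(bn, keep_ratio=0.5)
print(pruned_index)  # tensor([0, 2, 4])
```

A fixed global threshold on |γ| would be an equally plausible criterion; the top-k version simply guarantees a predictable number of surviving channels.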

When I finetune the pruned model, the following error occurs:
RuntimeError: Function CudnnBatchNormBackward returned an invalid gradient at index 1 - got [3] but expected shape compatible with [512]

Since the train -> prune -> finetune procedure runs in a single script, some gradient information (the shapes of the gradients and parameters) may still be stored from training.

I think this leftover information is where the error comes from.
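One piece of leftover state can be checked directly: assigning to `.data` swaps a parameter's values but leaves its `.grad` tensor, with the old shape, in place. The sketch below is a CPU-only illustration that assumes a small hypothetical `BatchNorm2d(8)` instead of the 512-channel layer from the error. It is not necessarily the only stale state involved, since autograd may also keep shape metadata for the parameter internally.

```python
import torch
import torch.nn as nn

# Hypothetical small layer standing in for the 512-channel BN in the error.
bn = nn.BatchNorm2d(8)
x = torch.randn(2, 8, 4, 4)

# One "training" step: populates bn.weight.grad with shape [8].
bn(x).sum().backward()

# Prune with .data, as in the question.
keep = torch.tensor([0, 2, 5])
bn.weight.data = bn.weight.data[keep]

print(bn.weight.shape)       # torch.Size([3])
print(bn.weight.grad.shape)  # torch.Size([8]): the gradient was not resized
```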

Currently I work around this by saving the model parameters to checkpoint.pth and then loading them into a freshly constructed model, so that the gradient information is cleared:

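import torch
from torchvision.models import resnet18

# train_model, prune_model and finetune are user-defined (not shown)
model = resnet18()

# train model
train_model(model)

# save checkpoint
torch.save(model.state_dict(), "checkpoint.pth")

# reload into a freshly constructed model (same path as in torch.save)
model = resnet18()
model.load_state_dict(torch.load("checkpoint.pth"))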

# prune
prune_model(model)

# finetune
finetune(model)
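The checkpoint round trip does not strictly need to touch the disk: `load_state_dict` copies the trained values into the fresh model's own Parameter objects, which have never taken part in a backward pass. A minimal sketch, assuming a hypothetical `make_model()` in place of `resnet18()` and a single backward pass in place of `train_model`:

```python
import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for resnet18(): any freshly constructed model works.
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())


model = make_model()
model(torch.randn(2, 3, 8, 8)).sum().backward()  # stand-in for train_model(model)

# Same idea as saving and reloading checkpoint.pth, but entirely in memory:
# the trained values (parameters and BN buffers) are copied into new objects.
fresh = make_model()
fresh.load_state_dict(model.state_dict())

print(all(p.grad is None for p in fresh.parameters()))  # True
```

This keeps the same structure as the original workaround (build a fresh model, copy the weights, then prune), so it is only a convenience, not a different technique.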

Is there an easy way to remove all of this gradient information?
My current solution feels rather inelegant.

I haven’t tested it, but manipulating the .data attribute might have unexpected side effects and is not recommended.
Could you instead wrap the pruning in a torch.no_grad() block and reassign the pruned parameters?
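A minimal sketch of that suggestion for a single conv/BN pair, assuming a hypothetical helper `prune_conv_bn` and a `keep` index tensor like `pruned_index` in the question. Each pruned tensor is wrapped in a new `nn.Parameter`, so it starts without a `.grad` or any autograd history from training. Because these are new objects, the sketch also re-creates the optimizer afterwards; an existing optimizer would still hold the old parameters. In a full network the next layer's input channels would also have to be pruned to match, which is omitted here.

```python
import torch
import torch.nn as nn


def prune_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d, keep: torch.Tensor) -> None:
    """Hypothetical helper: keep only the output channels listed in `keep`."""
    with torch.no_grad():
        # New Parameter objects: no .grad and no autograd history from training.
        conv.weight = nn.Parameter(conv.weight[keep])
        if conv.bias is not None:
            conv.bias = nn.Parameter(conv.bias[keep])
        bn.weight = nn.Parameter(bn.weight[keep])
        bn.bias = nn.Parameter(bn.bias[keep])
        # Running statistics are buffers, so plain tensors are reassigned.
        bn.running_mean = bn.running_mean[keep]
        bn.running_var = bn.running_var[keep]
    # Keep the modules' bookkeeping attributes in sync with the new shapes.
    conv.out_channels = len(keep)
    bn.num_features = len(keep)


conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(8)
x = torch.randn(2, 3, 16, 16)

# Stand-in for training: leaves gradients with the original 8-channel shapes.
bn(conv(x)).sum().backward()

keep = torch.tensor([0, 2, 5])
prune_conv_bn(conv, bn, keep)

# Fresh optimizer over the new Parameter objects, then one finetuning step.
optimizer = torch.optim.SGD(list(conv.parameters()) + list(bn.parameters()), lr=0.1)
optimizer.zero_grad()
out = bn(conv(x))
out.sum().backward()
optimizer.step()

print(out.shape)             # torch.Size([2, 3, 16, 16])
print(bn.weight.grad.shape)  # torch.Size([3])
```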

Wrapping the pruning in a torch.no_grad() block solved my problem.
Thanks!