I’m experimenting with some filter pruning tricks after training a model until convergence.
First, the model is trained with L1 regularization on the BatchNorm (BN) weights.
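For reference, the sparsity term is just an L1 penalty on all the BN scale factors, added to the task loss; a minimal sketch (bn_l1_penalty and the l1_lambda value here are illustrative, not my exact setup):

import torch.nn as nn

def bn_l1_penalty(model, l1_lambda=1e-4):
    # sum of |gamma| over every BatchNorm layer in the model
    penalty = sum(m.weight.abs().sum() for m in model.modules()
                  if isinstance(m, nn.BatchNorm2d))
    return l1_lambda * penalty

# inside the training loop:
# loss = criterion(output, target) + bn_l1_penalty(model)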
Then I manually prune convolution kernels and the corresponding BN parameters according to those BN weights:
# pruned_index: indices of the output channels that survive pruning
conv.weight.data = conv.weight.data[pruned_index]
bn.weight.data = bn.weight.data[pruned_index]
bn.bias.data = bn.bias.data[pruned_index]
bn.running_mean.data = bn.running_mean.data[pruned_index]
bn.running_var.data = bn.running_var.data[pruned_index]
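Here pruned_index is really the set of channel indices I keep, obtained by ranking the BN scale factors; a minimal sketch (the 0.5 keep ratio is just an illustrative value):

import torch

def kept_channel_indices(bn, keep_ratio=0.5):
    gamma = bn.weight.data.abs()
    n_keep = max(1, int(gamma.numel() * keep_ratio))
    # take the n_keep largest |gamma| values, sorted back into channel order
    return torch.topk(gamma, n_keep).indices.sort().values

pruned_index = kept_channel_indices(bn)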
When I finetune the pruned model, the following error occurred:
RuntimeError: Function CudnnBatchNormBackward returned an invalid gradient at index 1 - got [3] but expected shape compatible with [512]
Since the train -> prune -> finetune procedure runs in a single script, stale gradient information survives the pruning step: the parameters’ .grad attributes (and the gradient shapes autograd recorded for them during training) still have the pre-pruning sizes. I think this is where the error comes from.
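This is easy to check: right after pruning, a parameter has its new shape while its .grad attribute still has the old one (the shapes below are illustrative, matching the error above):

print(bn.weight.data.shape)  # torch.Size([3])   -- the pruned shape
print(bn.weight.grad.shape)  # torch.Size([512]) -- stale gradient left over from training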
Currently I work around this by saving the model parameters to checkpoint.pth and then reloading them into a fresh model, so that the stale gradient information is cleared:
import torch
from torchvision.models import resnet18  # or wherever the model definition lives

model = resnet18()
# train the model (with L1 regularization on the BN weights)
train_model(model)
# save only the parameters
torch.save(model.state_dict(), "checkpoint.pth")
# reload into a fresh model, whose parameters carry no .grad attributes yet
model = resnet18()
model.load_state_dict(torch.load("checkpoint.pth"))
# prune
prune_model(model)
# finetune
finetune(model)
So, is there a way to easily remove all of this gradient information? My current workaround does not feel very elegant.
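For example, would resetting every parameter’s .grad right after pruning be enough, or is there more cached state to clear?

for p in model.parameters():
    p.grad = None  # drop stale gradients that still have the pre-pruning shapes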