How to ignore a pruned module during fine-tuning with torch.nn.utils.prune

What I want to do
I am fine-tuning a deep learning model with convolutional layers after pruning it.
I want the pruned modules to be ignored (i.e. not updated) during retraining.

What I did

  • I pruned an nn.Conv2d module with amount=1.0 using nn.utils.prune.random_structured(), then retrained the model.
""" 各演算候補の定義 """
class SimpleModel(nn.Module):
  def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
    super(SimpleModel, self).__init__()
    self.process = nn.Sequential(
      nn.ReLU(inplace=False),
      nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
      nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding, bias=False),
      nn.BatchNorm2d(C_out, affine=affine)
    )
    self.classifier = nn.Linear(5*32*32, 10)

  def forward(self, x):
    x = self.process(x)
    x = self.classifier(x.view(x.size(0),-1))
    return x
  
model = SimpleModel(3, 5, 3, 1, 1).cuda()
optimizer = torch.optim.SGD(
      model.parameters(),
      lr=0.1,
      momentum=0.9,
      weight_decay=3e-4)
criterion = nn.CrossEntropyLoss().cuda()

for name, module in model.named_modules():
  if name == 'process.1':
    prune.random_structured(module, name='weight', amount=1.0, dim=1)
    # prune.remove(module, name='weight')

train_data = dataset.CIFAR10(root='../data', train=True, download=True, transform=transforms.ToTensor())
train_queue = torch.utils.data.DataLoader(train_data, batch_size=1)

for step, (data, target) in enumerate(train_queue):
    data, target = data.cuda(), target.cuda()

    optimizer.zero_grad()
    logits = model(data)
    loss = criterion(logits, target)
    loss.backward()
    optimizer.step()

    
  • I observed weight and weight_orig during an experimental retraining.
    While retraining, weight_orig was updated but weight stayed at zero,
    so the output of that module was completely zero.
  • When nn.utils.prune.remove() was executed, weight_orig was removed and only weight remained.
    Retraining confirmed that weight was updated.
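For reference, here is a minimal sketch (with hypothetical layer sizes, not my actual model) of what the pruning call registers on the module: the original weight parameter is replaced by weight_orig, a weight_mask buffer is added, and weight becomes a plain tensor recomputed as weight_orig * weight_mask:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical small conv layer, pruned the same way as in the question.
conv = nn.Conv2d(3, 5, 3, padding=1, bias=False)
prune.random_structured(conv, name='weight', amount=1.0, dim=1)

# 'weight' is no longer a Parameter; only 'weight_orig' is trainable,
# and 'weight_mask' is stored as a buffer.
print(sorted(name for name, _ in conv.named_parameters()))  # ['weight_orig']
print(sorted(conv.state_dict().keys()))  # ['weight_mask', 'weight_orig']

# 'weight' is recomputed from weight_orig * weight_mask (a forward
# pre-hook refreshes it on every forward pass).
print(torch.equal(conv.weight, conv.weight_orig * conv.weight_mask))  # True
```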

Here are my questions.

  • If remove() is not called, is forward propagation computed using weight_orig or weight? And why is weight_orig updated?
    Ideally, weight_orig should not be updated either, since I want backpropagation to ignore the pruned module during retraining. (If weight_orig is being updated, that suggests weight_orig is being used in the gradient computation, which is far from ideal.)

  • If remove() is executed, how do I stop the weights of the pruned module from being updated?

  • In general, after pruning, how can I retrain the model while ignoring the pruned module?
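To make the questions concrete, here is a minimal diagnostic sketch (hypothetical layer sizes, not my actual setup) that checks the gradient actually reaching weight_orig before remove() is called. Since weight = weight_orig * weight_mask, and the mask is all zeros with amount=1.0, I would expect the loss gradient on weight_orig to be zero; if weight_orig still changes after optimizer.step(), the change would have to come from the optimizer itself (e.g. the weight_decay or momentum terms of SGD) rather than from the loss:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Hypothetical small conv layer, pruned the same way as in the question.
conv = nn.Conv2d(3, 5, 3, padding=1, bias=False)
prune.random_structured(conv, name='weight', amount=1.0, dim=1)

# One forward/backward pass through the pruned layer.
out = conv(torch.randn(1, 3, 8, 8))
out.sum().backward()

# The gradient on weight_orig is grad(weight) * weight_mask, and the mask
# is all zeros here, so the loss contributes no gradient at all.
print(conv.weight_orig.grad.abs().max())  # tensor(0.)
```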