How can I modify and update the weights and gradients of a few nodes after pruning during training?

Hi, I’m a newbie in PyTorch.

Currently, starting from a simple CNN architecture, I would like to apply a pruning step layer-wise in my network, keeping only a few nodes (similar to dropout). By building a mask with the same shape as the gradient of the input/output (i.e. the activation map), I want to modify the weights or gradients during the backpropagation step.

The flow I have in mind is as follows (a rough sketch of the full loop is given right after the list):

  1. forward pass,
  2. loss calculation,
  3. build a mask in each layer from the weights (W), bias (b), and the gradients of the input (dL/dx) and output (dL/dy), etc.,
  4. modify the gradients with this mask, which consists of 1s (meaningful nodes) and 0s (meaningless nodes),
  5. update the weights layer-wise using the modified gradients.
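
To make this concrete, here is a rough sketch of the whole loop as I imagine it (`make_mask` is just a hypothetical placeholder for step 3, not an existing function, and the threshold it uses is made up):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                          nn.Flatten(), nn.Linear(8 * 28 * 28, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loader = [(torch.randn(4, 1, 28, 28), torch.randint(0, 10, (4,)))]  # dummy data

    def make_mask(param):
        # step 3 placeholder: in my case this would use W, b, dL/dx, dL/dy, etc.
        return (param.grad.abs() > 1e-3).float()

    for x, y in loader:
        out = model(x)                        # 1. forward pass
        loss = F.cross_entropy(out, y)        # 2. loss calculation
        optimizer.zero_grad()
        loss.backward()
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(make_mask(p))     # 3.-4. mask of 1s/0s applied to grads
        optimizer.step()                      # 5. weights updated with masked grads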

I have already read similar Q&As in the PyTorch community and tried to apply them as below.

Case 1) Using register_backward_hook: the only example I could find multiplies the gradient by a scalar (e.g. gradients * 10). I can get the gradient of the input (dL/dx) or output (dL/dy), which has the same shape as my mask, but I don’t know how to write the modified gradient back so that it is actually used in the update.
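
Concretely, what I tried for Case 1 looks roughly like this (not sure this is the right way to do it; `grad_mask` stands for the mask from step 3 and is assumed to have the same shape as the gradient it masks):

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    grad_mask = None  # placeholder: would be computed before/during backward

    def mask_hook(module, grad_input, grad_output):
        if grad_mask is None:
            return None  # leave the gradients unchanged
        # returning a new tuple replaces grad_input for this module
        return tuple(g * grad_mask if g is not None and g.shape == grad_mask.shape else g
                     for g in grad_input)

    handle = conv.register_backward_hook(mask_hook)

Alternatively, registering a hook directly on a tensor, e.g. `conv.weight.register_hook(lambda g: g * weight_mask)`, also seems to let me rewrite the weight gradient, but again I’m not sure this is the intended way.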

Case 2) Writing my own function class with torch.autograd.Function, as below. However, 1) I couldn’t find any example for the convolutional case (I can’t step into the internals of a conv layer to debug), and 2) all I want to do is apply a pre-calculated mask to the gradients, and for that purpose this approach is hard for me to follow.

    import torch

    class my_function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias):
            # reimplement the linear forward, stash what backward needs
            ctx.save_for_backward(input, weight, bias)
            if input.dim() == 2 and bias is not None:
                # fused op is marginally faster
                return torch.addmm(bias, input, weight.t())

            output = input.matmul(weight.t())
            if bias is not None:
                output += bias
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight, bias = ctx.saved_tensors
            grad_input = grad_weight = grad_bias = None

            if ctx.needs_input_grad[0]:
                grad_input = grad_output.mm(weight)
            if ctx.needs_input_grad[1]:
                grad_weight = grad_output.t().mm(input)
            if bias is not None and ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum(0)

            # return one gradient per forward input, in the same order
            return grad_input, grad_weight, grad_bias
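
If I understand Case 2 correctly, what I actually want is something like the variant below: pass the pre-calculated mask into forward, stash it, and multiply grad_output by it in backward (just my guess at how it should look; `mask` is assumed to be a 0/1 tensor with the same shape as grad_output):

    import torch

    class masked_linear(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, weight, bias, mask):
            ctx.save_for_backward(input, weight, bias, mask)
            output = input.matmul(weight.t())
            if bias is not None:
                output += bias
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight, bias, mask = ctx.saved_tensors
            grad_output = grad_output * mask              # apply the pruning mask
            grad_input = grad_output.mm(weight)
            grad_weight = grad_output.t().mm(input)
            grad_bias = grad_output.sum(0) if bias is not None else None
            # one return value per forward argument; None for the mask itself
            return grad_input, grad_weight, grad_bias, None

    # usage: out = masked_linear.apply(x, weight, bias, mask)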

I am now referring to the ICML 2017 paper [https://arxiv.org/abs/1706.06197], which does Sparsified Back Propagation.
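
For reference, this is roughly how I picture building the mask from dL/dy following that paper, i.e. keeping only the top-k entries by magnitude per sample (just my own reading of it, not the authors’ code):

    import torch

    def topk_mask(grad_output, k):
        # 1 for the k largest-magnitude entries of dL/dy in each sample, 0 elsewhere
        flat = grad_output.abs().reshape(grad_output.size(0), -1)
        _, idx = flat.topk(k, dim=1)
        mask = torch.zeros_like(flat)
        mask.scatter_(1, idx, 1.0)
        return mask.reshape(grad_output.shape)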
Any advice on how to implement this concept would be very helpful. Thank you!