Opacus: How to disable backward hooks temporarily for multiple backward passes


I’m using Opacus to compute per-sample gradients w.r.t. the parameters. However, I also need the per-sample gradient of each logit w.r.t. the input, so I have to backpropagate several times through the same forward pass. A minimal example:

```python
import torch
from opacus.grad_sample import GradSampleModule
from torch.autograd import grad

class Model(torch.nn.Module):
    def __init__(self, inputSize, outputSize):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(inputSize, outputSize)

    def forward(self, x):
        out = self.linear(x)
        return out

num_classes = 2
bs = 16

model = Model(10, num_classes)
op_model = GradSampleModule(model)

X = torch.rand(bs, 10)
X.requires_grad = True
grad_x = torch.zeros(bs, num_classes, 10)  # grad_x[i, c, :] = d output[i, c] / d X[i]

output = op_model(X)  # bs * num_classes
for c in range(num_classes):
    # The rest of this call was cut off; grad_outputs and retain_graph are assumed
    # (both are needed to backprop a non-scalar output more than once).
    grad_x[:, c, :] = grad(outputs=output[:, c], inputs=X,
                           grad_outputs=torch.ones(bs), retain_graph=True)[0]
```

This fails with the following error:

IndexError: pop from empty list

According to this post, Opacus doesn’t work when the number of backward passes exceeds the number of forward passes. However, in my use case (adversarial training), at some point I need to backpropagate several times to compute input gradients (not parameter gradients). Is it possible to disable the grad_sample functionality temporarily and re-enable it afterwards? I’d appreciate any suggestions.
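To make the failure mode concrete, here is a minimal sketch in plain PyTorch. `ToyGradSample` and `input_grads` are hypothetical names, and the class is a simplified stand-in that only assumes Opacus-style bookkeeping: a forward hook pushes the layer input onto a list and a backward hook pops it to build per-sample gradients. One forward followed by two backward passes therefore pops an empty list. Gating both hooks on an `enabled` flag, switched off *before* the forward pass used for input gradients, keeps pushes and pops balanced.

```python
import torch
import torch.nn as nn


class ToyGradSample:
    """Hypothetical, simplified stand-in for Opacus-style per-sample gradient hooks.

    The forward hook pushes the layer input onto a list; the backward hook pops
    it to build per-sample weight gradients. That is one pop per push, so a
    second backward pass over the same forward raises IndexError unless the
    hooks are switched off through the `enabled` flag.
    """

    def __init__(self, layer: nn.Linear):
        self.enabled = True
        self.activations = []
        layer.register_forward_hook(self._capture_activations)
        layer.register_full_backward_hook(self._capture_backprops)

    def _capture_activations(self, module, inputs, output):
        if self.enabled:
            self.activations.append(inputs[0].detach())

    def _capture_backprops(self, module, grad_input, grad_output):
        if not self.enabled:
            return None
        a = self.activations.pop()  # IndexError when no forward is left to match
        g = grad_output[0].detach()
        # Per-sample weight gradient: outer(g[n], a[n]), shape (bs, out, in).
        module.weight.grad_sample = torch.einsum("ni,nj->nij", g, a)
        return None


def input_grads(layer, X, num_classes):
    """One forward pass, then one backward pass per logit (as in the question)."""
    out = layer(X)
    grads = torch.zeros(X.shape[0], num_classes, X.shape[1])
    for c in range(num_classes):
        grads[:, c, :] = torch.autograd.grad(out[:, c].sum(), X, retain_graph=True)[0]
    return grads


torch.manual_seed(0)
bs, d_in, num_classes = 16, 10, 2
layer = nn.Linear(d_in, num_classes)
tracker = ToyGradSample(layer)
X = torch.rand(bs, d_in, requires_grad=True)

# 1) Hooks on: the second backward pass finds the activation list empty.
try:
    input_grads(layer, X, num_classes)
    failed = False
except IndexError as e:
    failed = True
    print("IndexError:", e)

# 2) Hooks off for the whole forward + multi-backward used for input gradients...
tracker.enabled = False
grad_x = input_grads(layer, X, num_classes)
tracker.enabled = True

# 3) ...then back on for a normal training step (one forward, one backward).
layer.zero_grad()
layer(X).sum().backward()
print(grad_x.shape, layer.weight.grad_sample.shape)
```

For a single linear layer, each `grad_x[:, c, :]` should equal row `c` of the weight matrix, which is an easy sanity check. Opacus’s `GradSampleModule` also defines `disable_hooks()` and `enable_hooks()` methods built around a similar on/off flag; whether calling them around the adversarial forward pass avoids the error may depend on the installed Opacus version, so treat that as something to verify.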