Temporarily bypassing autograd

I’m trying to create a custom module that essentially does the masking operation in this paper:
Interpretable CNNs
Essentially it’s a module that, on the forward pass, masks out the smaller activations of the input feature map, and on the backward pass adds an extra loss term that is only approximately differentiable (so during gradient descent it just transforms grad_output a little). I think autograd is still trying to differentiate through the forward pass, though, which shouldn’t even be possible because the forward pass uses argmax.
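For reference, here is the general pattern I’m trying to follow, as a minimal sketch (the names are mine, not from the paper): inside a custom torch.autograd.Function the forward pass is not traced by autograd, so using argmax there is fine, and backward() can hand back whatever surrogate gradient I want.

import torch

class KeepMaxOnly(torch.autograd.Function):

  @staticmethod
  def forward(ctx, x):
    # non-differentiable op: keep only the largest activation
    flat_idx = torch.argmax(x)
    keep = torch.zeros_like(x)
    keep.view(-1)[flat_idx] = 1.0
    ctx.save_for_backward(keep)
    return x * keep

  @staticmethod
  def backward(ctx, grad_output):
    # hand-written surrogate gradient: pass the gradient through the kept position
    (keep,) = ctx.saved_tensors
    return grad_output * keep

x = torch.randn(3, 3, requires_grad=True)
y = KeepMaxOnly.apply(x)
y.sum().backward()  # works even though forward() used argmax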
Here’s what I have so far:

def mask(featMap,templates,neg_t):
  # index of the largest activation in the flattened (H, W) feature map
  max_idx=torch.argmax(torch.flatten(featMap))
  # unflatten to (row, col): integer-divide and mod by the width, not /
  ind=(max_idx//featMap.size()[1],max_idx%featMap.size()[1])
  # pick the template centred on the maximum and apply it
  temp = templates[ind[0]][ind[1]]
  masked = temp*featMap
  return masked

def Loss_f(featMap,templates,neg_t,Z_T):
  # filter loss from the paper; p_x and get_Z are helpers defined elsewhere
  ret=0
  n2 = float(featMap.size()[0]*featMap.size()[1])
  alpha=n2/(1.+n2)
  pr_t=alpha/n2
  pr_tn=1.-alpha
  pr_x=0
  for i,template in enumerate(templates):
    pr_x+=p_x(featMap,template)/Z_T[i]
  pr_x*=pr_t
  pr_x+=pr_tn*p_x(featMap,neg_t)/Z_T[-1]
  for i,template in enumerate(templates):
    tr=torch.sum(featMap*template)
    ret+=pr_t*torch.exp(tr)*(tr-torch.log(Z_T[i])-torch.log(pr_x))/Z_T[i]
  tr=torch.sum(featMap*neg_t)
  ret+=pr_tn*torch.exp(tr)*(tr-torch.log(Z_T[-1])-torch.log(pr_x))/Z_T[-1]
  return ret


class MaskFunc(torch.autograd.Function):

  @staticmethod
  def forward(ctx, input, templates, neg_t):
    # apply the mask to every (H, W) feature map of every sample in the batch
    ret = torch.stack([torch.stack([mask(featMap, templates, neg_t) for featMap in samples])
                       for samples in input])
    ctx.save_for_backward(input, ret, templates, neg_t)
    return ret

  @staticmethod
  def backward(ctx, grad_output):
    # saved_tensors is a property, not a method
    input, outs, templates, neg_t = ctx.saved_tensors
    grad_input = None
    if ctx.needs_input_grad[0]:
      lbda = 1
      # list.append() returns None, so build the whole list instead
      Z_T = [get_Z(templates[i]) for i in range(templates.size()[0])] + [get_Z(neg_t)]
      # add the (approximate) gradient of the filter loss per (H, W) feature map,
      # so the result keeps the same shape as the input
      loss_grad = torch.stack([torch.stack([Loss_f(featMap, templates, neg_t, Z_T) * featMap * lbda
                                            for featMap in sample])
                               for sample in input])
      grad_input = grad_output + loss_grad
    # forward() took three inputs, so backward() must return three gradients
    return grad_input, None, None



class ConvMask(nn.Module):

  def __init__(self, in_channels, size):
    super(ConvMask,self).__init__()
    self.in_channels = in_channels
    self.templates,self.neg_t = self.init_templates(size)

  def L1(self,x1,x2):
    # Manhattan (L1) distance between two 2-D positions
    return abs(x1[0]-x2[0])+abs(x1[1]-x2[1])

  def init_templates(self,size):
    tau = 0.5/float(size*size)
    beta = 4
    # one (size x size) template per position of the maximum, plus a uniform negative template
    ret = torch.tensor([[[[tau*max(1-beta*self.L1((float(i),float(j)),(float(a),float(b)))/float(size),-1)
                           for a in range(size)]
                          for b in range(size)]
                         for j in range(size)]
                        for i in range(size)])
    negT = torch.tensor([[-tau for a in range(size)] for b in range(size)])
    return ret,negT

  def forward(self,input):
    return MaskFunc.apply(input,self.templates,self.neg_t)

But I get this error:

[code]
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [4, 512, 7, 7]], which is output 0 of MaskFuncBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
[/code]
when it unpacks ctx.saved_tensors. How can I, in a sense, “turn off” autograd for this module and “resume” it for the following module?
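For context, here is a minimal sketch of how I think this error can arise (the wrapper module below is illustrative, not my actual model): MaskFunc saves its own output with save_for_backward, so if any later layer modifies that output in place, e.g. an nn.ReLU(inplace=True), its version counter ticks up and unpacking ctx.saved_tensors fails with exactly this message.

import torch
import torch.nn as nn

class ToyModel(nn.Module):
  # illustrative wrapper, assuming ConvMask/MaskFunc as defined above

  def __init__(self, in_channels, size):
    super().__init__()
    self.convmask = ConvMask(in_channels, size)
    self.act = nn.ReLU(inplace=True)  # in-place op right after the custom Function

  def forward(self, x):
    out = self.convmask(x)  # this output was saved by save_for_backward(...)
    return self.act(out)    # ...and is then modified in place -> version 1

If that is the cause, the fixes would be to use nn.ReLU(inplace=False), to clone the output before the in-place op, or to stop saving ret in forward() since backward() never reads it.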

Edit: I think I figured it out, so the original question is no longer really relevant.