I’m trying to create a custom module that essentially does the masking operation in this paper:
Interpretable CNNs
Essentially it’s a module that, on the forward pass, masks out the smaller activations of the input feature map, and on the backwards pass simply adds an additional loss which is approximately differentiable (so when doing gradient descent, it just transforms the grad_output a little). I think autograd is still trying to differentiate from the forwards pass, though, which is technically impossible because the forwards pass includes a usage of argmax.
Here’s what I have so far:
def mask(featMap,templates,neg_t):
max=torch.argmax(torch.flatten(featMap))
ind=torch.tensor((max/featMap.size()[0],max%featMap.size()[0]))
temp = templates[ind[0]][ind[1]]
masked = temp*featMap
return masked
def Loss_f(featMap,templates,neg_t,Z_T):
ret=0
n2 = float(featMap.size()[0]*featMap.size()[1])
alpha=n2/(1.+n2)
pr_t=alpha/n2
pr_tn=1.-alpha
pr_x=0
for i,template in enumerate(templates):
pr_x+=p_x(featMap,template)/Z_T[i]
pr_x*=pr_t
pr_x+=pr_tn*p_x(featMap,neg_t)/Z_T[-1]
for i,template in enumerate(templates):
tr=torch.sum(featMap*template)
ret+=pr_t*torch.exp(tr)*(tr-torch.log(Z_T[i])-torch.log(pr_x))/Z_T[i]
tr=torch.sum(featMap*neg_t)
ret+=pr_tn*torch.exp(tr)*(tr-torch.log(Z_T[-1])-torch.log(pr_x))/Z_T[-1]
return ret
class MaskFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input,templates,neg_t):
ret = torch.stack([torch.stack([mask(featMap,templates,neg_t) for featMap in samples]) for samples in input])
ctx.save_for_backward(input,ret,templates,neg_t)
return ret
@staticmethod
def backward(ctx,grad_output):
input,outs,templates,neg_t = ctx.saved_tensors()
input = torch.mean(input)
grad_input=None
if ctx.needs_input_grad[0]:
grad_input=torch.mean(grad_output)
lbda=1
Z_T=[get_Z(templates[i]) for i in range(templates.size()[0])].append(get_Z(neg_t))
grad_input+=torch.tensor([Loss_f(featMap,templates,neg_t,Z_T)*featMap*lbda for featMap in input])
return grad_input
class ConvMask(nn.Module):
def __init__(self, in_channels, size):
super(ConvMask,self).__init__()
self.in_channels = in_channels
self.templates,self.neg_t = self.init_templates(size)
def L1(self,x1,x2):
return abs(x1[0]-x2[0])+abs(x1[1]-x2[1])
def init_templates(self,size):
tau = 0.5/float(size*size)
beta=4
ret = torch.tensor([[[[tau*max(1-beta*self.L1((float(i),float(j)),(float(a),float(b)))/float(size),-1) for a in range(size)] for b in range(size)] for j in range(size)] for i in range(size)])
negT=torch.tensor([[-tau for a in range(size)] for b in range(size)])
return ret,negT
def forward(self,input):
return MaskFunc.apply(input,self.templates,self.neg_t)
But I get this error:
[code]
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [4, 512, 7, 7]], which is output 0 of MaskFuncBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
[\code]
when it’s unpacking the ctx.saved_tensors(). How can I, in a sense, “turn off” autograd for this module and “resume” it on the following module?