To everyone who faces a similar problem, i.e. manipulating gradients in the backward step: I found a working solution for myself by breaking the problem down into two steps.
import torch
import torch.nn.init as nnInit
from torch.autograd import Variable
# implementing the TRAINED TERNARY QUANTIZATION layer
# Zhu et al. ICLR 2017
# 1a) split the method into two functions: the first one TERNARIZES
# the weights; the second one just calls the standard convolution,
# i.e. it is NOT necessary to reimplement that behavior
# 1b) implement the forward and backward functionality as a
# torch.autograd.Function that just manipulates the gradients
# passed back from the convolution layer
# 2) wrap these 2 functions in a torch.nn.Module
# write the first autograd function -> override forward & backward:
# pass the image data and its gradient through (left unchanged!) and just
# implement the weight manipulations: forward -> ternarize;
# backward -> do the custom gradient manipulation on the weight gradients
# generated by the standard convolution layer that follows
class ternarizeWeights(torch.autograd.Function):
def __init__(self, ternThres):
super(ternarizeWeights, self).__init__()
self.tThresScale = ternThres
def forward(self, img2Conv, fpWeights, tWp, tWn):
self.save_for_backward(img2Conv, fpWeights, tWp, tWn)
        # compute the ternarized weights with the given threshold
        quantW = fpWeights.clone()
        # normalize the weights to [-1, 1]
        quantW = quantW / quantW.abs().max()
        # compute the threshold AFTER normalizing, so the masks below match
        # the ones used in backward() on the unnormalized weights
        quantThres = self.tThresScale * quantW.abs().max()
# quantize to {-Wn, 0, Wp}
quantW[quantW.abs() <= quantThres] = 0.0
quantW[quantW < -1.0*quantThres] = -1.0*tWn[0]
quantW[quantW > quantThres] = tWp[0]
return img2Conv, quantW
def backward(self, grad_outputImg, grad_outputW):
img2Conv, fpWeights, tWp, tWn = self.saved_tensors
quantThres = self.tThresScale * fpWeights.abs().max()
# leave the image gradient unchanged
grad_img2Conv = grad_outputImg
        # all other gradients have to be manipulated:
        # clone tensors of the right shape and dtype as gradient buffers
        # (every entry gets overwritten by the masked assignments below)
grad_fpWeights = fpWeights.clone()
grad_tWp = tWp.clone()
grad_tWn = tWn.clone()
grad_tWp[0] = grad_outputW[fpWeights > quantThres].sum()
grad_tWn[0] = grad_outputW[fpWeights < -1.0 * quantThres].sum()
grad_fpWeights[fpWeights > quantThres] = grad_outputW[fpWeights > quantThres] * tWp[0]
grad_fpWeights[fpWeights < -1.0 * quantThres] = grad_outputW[fpWeights < -1.0 * quantThres] * tWn[0]
grad_fpWeights[fpWeights.abs() <= quantThres] = 1.0 * grad_outputW[fpWeights.abs() <= quantThres]
return grad_img2Conv, grad_fpWeights, grad_tWp, grad_tWn
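# NOTE (aside): the class above uses the legacy, instance-based autograd.Function
# API (__init__ / self.save_for_backward), which matches older PyTorch releases.
# Newer releases (>= 0.4 / 1.x) expect static forward/backward methods with a
# ctx object. A minimal, untested sketch of the same ternarization step in that
# style (my assumption: the threshold scale is passed as an extra forward
# argument that receives no gradient) could look roughly like this:
class TernarizeWeightsFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, img2Conv, fpWeights, tWp, tWn, ternThres):
        ctx.save_for_backward(fpWeights, tWp, tWn)
        ctx.ternThres = ternThres
        # normalize, then quantize to {-Wn, 0, Wp}
        quantW = fpWeights / fpWeights.abs().max()
        out = torch.zeros_like(quantW)
        out[quantW > ternThres] = tWp[0]
        out[quantW < -ternThres] = -tWn[0]
        return img2Conv, out
    @staticmethod
    def backward(ctx, grad_outputImg, grad_outputW):
        fpWeights, tWp, tWn = ctx.saved_tensors
        thres = ctx.ternThres * fpWeights.abs().max()
        posMask, negMask = fpWeights > thres, fpWeights < -thres
        midMask = fpWeights.abs() <= thres
        # scale the weight gradients region-wise, as in the legacy version
        grad_fpWeights = torch.zeros_like(fpWeights)
        grad_fpWeights[posMask] = grad_outputW[posMask] * tWp[0]
        grad_fpWeights[negMask] = grad_outputW[negMask] * tWn[0]
        grad_fpWeights[midMask] = grad_outputW[midMask]
        # accumulate the gradients for the two scaling factors
        grad_tWp = grad_outputW[posMask].sum().reshape(1)
        grad_tWn = grad_outputW[negMask].sum().reshape(1)
        # the image gradient is passed through unchanged; ternThres gets None
        return grad_outputImg, grad_fpWeights, grad_tWp, grad_tWn, None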
# write the second building block as a plain torch.nn.Module -> only forward:
# call the standard convolution function on the unchanged image but
# with the ternarized weights (2D case) and let autograd pass the gradients back
# -> the preceding ternarization Function handles the gradient manipulation
class execute2DConvolution(torch.nn.Module):
def __init__(self, inStride=1, inPadding=0, inDilation=1, inGroups=1):
super(execute2DConvolution, self).__init__()
self.cStride = inStride
self.cPad = inPadding
self.cDil = inDilation
self.cGrp = inGroups
def forward(self, dataIn, weightIn):
return torch.nn.functional.conv2d(dataIn, weightIn, bias=None,
stride=self.cStride, padding=self.cPad,
dilation=self.cDil, groups=self.cGrp)
# wrap the two building blocks inside one module, so that to the outside
# world it appears as a customized Conv2D / Conv1D layer
class ttqConv2D(torch.nn.Module):
def __init__(self, inWCin, inWCout, inWH,
inStride=1, inPad=0, inDil=1, inGroups=1, inTScale=0.05):
super(ttqConv2D, self).__init__()
# initialize all parameters that the convolution function needs to know
self.conStride = inStride
self.conPad = inPad
self.outPad = 0
self.conDil = inDil
self.conTrans = False
self.conGroups = inGroups
        # initialize the threshold scale, the full-precision weights
        # and the two ternary scaling factors (there is no bias)
self.tScale = inTScale
self.fpWeight = torch.nn.Parameter(torch.Tensor(inWCout, inWCin, inWH, inWH))
# xavier weight initialization
nnInit.xavier_normal(self.fpWeight)
self.tWeightPos = torch.nn.Parameter(torch.Tensor([0]))
self.tWeightNeg = torch.nn.Parameter(torch.Tensor([0]))
nnInit.uniform(self.tWeightPos, 0.8, 1.2)
nnInit.uniform(self.tWeightNeg, 0.8, 1.2)
def forward(self, dataInput):
dInput, tWeights = ternarizeWeights(self.tScale)(dataInput, self.fpWeight,
self.tWeightPos, self.tWeightNeg)
return execute2DConvolution(self.conStride, self.conPad,
self.conDil, self.conGroups)(dInput, tWeights)
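# aside (not from the original post): fpWeight, tWeightPos and tWeightNeg are
# all registered as Parameters, so a standard optimizer updates the
# full-precision weights and both scaling factors with the gradients produced
# by the custom backward pass, e.g. (hypothetical layer name):
#   optimizer = torch.optim.SGD(myTtqLayer.parameters(), lr=1e-3)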
# just the same functionality as in the 2D case.
class execute1DConvolution(torch.nn.Module):
def __init__(self, inStride=1, inPadding=0, inDilation=1, inGroups=1):
super(execute1DConvolution, self).__init__()
self.cStride = inStride
self.cPad = inPadding
self.cDil = inDilation
self.cGrp = inGroups
def forward(self, dataIn, weightIn):
return torch.nn.functional.conv1d(dataIn, weightIn, bias=None,
stride=self.cStride, padding=self.cPad,
dilation=self.cDil, groups=self.cGrp)
class ttqConv1D(torch.nn.Module):
def __init__(self, inWCin, inWCout, inL,
inStride=1, inPad=0, inDil=1, inGroups=1, inTScale=0.05):
super(ttqConv1D, self).__init__()
# initialize all parameters that the convolution function needs to know
self.conStride = inStride
self.conPad = inPad
self.outPad = 0
self.conDil = inDil
self.conTrans = False
self.conGroups = inGroups
        # initialize the threshold scale, the full-precision weights
        # and the two ternary scaling factors (there is no bias)
self.tScale = inTScale
self.fpWeight = torch.nn.Parameter(torch.Tensor(inWCout, inWCin, inL))
# xavier weight initialization
nnInit.xavier_normal(self.fpWeight)
self.tWeightPos = torch.nn.Parameter(torch.Tensor([0]))
self.tWeightNeg = torch.nn.Parameter(torch.Tensor([0]))
nnInit.uniform(self.tWeightPos, 0.8, 1.2)
nnInit.uniform(self.tWeightNeg, 0.8, 1.2)
def forward(self, dataInput):
dInput, tWeights = ternarizeWeights(self.tScale)(dataInput, self.fpWeight,
self.tWeightPos, self.tWeightNeg)
return execute1DConvolution(self.conStride, self.conPad,
self.conDil, self.conGroups)(dInput, tWeights)
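# quick usage sketch (hedged): a smoke test that the wrapped layer runs a
# forward and a backward pass; shapes and hyperparameters below are made up
# for illustration and assume the legacy (Variable-based) API used above
if __name__ == '__main__':
    conv = ttqConv2D(inWCin=3, inWCout=8, inWH=3, inStride=1, inPad=1, inTScale=0.05)
    x = Variable(torch.randn(4, 3, 32, 32))
    out = conv(x)
    loss = out.sum()
    loss.backward()
    # after backward(), the full-precision weights and both ternary scaling
    # factors carry the manipulated gradients
    print(out.size(), conv.fpWeight.grad.size(),
          conv.tWeightPos.grad, conv.tWeightNeg.grad)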