Call backward on function inside a backpropagation step

To everyone facing a similar problem, i.e. manipulating gradients in the backward step: I have now found a working solution by breaking the problem down into two steps.

import torch
import torch.nn.init as nnInit
from torch.autograd import Variable


# implementing the TRAINED TERNARY QUANTIZATION (TTQ) layer
# Zhu et al. ICLR 2017
# 1a)   split the method into two functions: the first one TERNARIZES
#       the weights; the second one just calls the standard convolution,
#       i.e. it is NOT necessary to implement the convolution itself
# 1b)   implement the forward and backward functionality as a
#       torch.autograd.Function that only manipulates the gradients
#       passed back from the convolution layer
# 2)    wrap these 2 functions in a torch.nn.Module


# write the first autograd function -> override forward & backward:
# pass the image data and its gradient through (left unchanged!) and only
# implement the weight manipulations: forward -> ternarize;
# backward -> apply the custom manipulation to the weight gradients that
# were produced by the standard convolution layer that follows


class _TernarizeWeightsFunction(torch.autograd.Function):
    """Ternarize a full-precision weight tensor (TTQ, Zhu et al. ICLR 2017).

    New-style autograd function (static ``forward``/``backward`` taking
    ``ctx``), as required by current PyTorch. Invoke it via
    ``_TernarizeWeightsFunction.apply(fpWeights, tWp, tWn, tThresScale)``.
    """

    @staticmethod
    def forward(ctx, fpWeights, tWp, tWn, tThresScale):
        """Map each weight to tWp[0], 0 or -tWn[0].

        Args:
            fpWeights: full-precision weight tensor.
            tWp: tensor whose element 0 is the positive scaling factor.
            tWn: tensor whose element 0 is the negative scaling factor.
            tThresScale: threshold scale t (Python number).

        Returns:
            Tensor shaped like ``fpWeights`` with values in {-tWn[0], 0, tWp[0]}.
        """
        # Comparing the raw weights against t * max|w| is equivalent to
        # comparing the weights normalized to [-1, 1] against t. The same
        # masks are saved for backward, so both passes agree on the
        # positive / zero / negative partition.
        quantThres = tThresScale * fpWeights.abs().max()
        posMask = fpWeights > quantThres
        negMask = fpWeights < -quantThres
        ctx.save_for_backward(posMask, negMask, tWp, tWn)
        # Build the output from masks of the ORIGINAL weights instead of
        # overwriting one tensor step by step. An already-assigned value
        # (e.g. -Wn when Wn < 0) therefore can't be re-matched by a later mask.
        zeros = torch.zeros_like(fpWeights)
        return torch.where(posMask, tWp[0],
                           torch.where(negMask, -tWn[0], zeros))

    @staticmethod
    def backward(ctx, grad_outputW):
        """Return gradients for (fpWeights, tWp, tWn, tThresScale)."""
        posMask, negMask, tWp, tWn = ctx.saved_tensors
        # Scaled straight-through estimator for the full-precision weights
        # (TTQ paper, eq. 8): scale by Wp / Wn in the positive / negative
        # regions and pass the gradient through unchanged in the zero region.
        grad_fpWeights = torch.where(
            posMask, grad_outputW * tWp[0],
            torch.where(negMask, grad_outputW * tWn[0], grad_outputW))
        # Only element 0 of tWp / tWn is used in forward, so every other
        # element receives a zero gradient.
        grad_tWp = torch.zeros_like(tWp)
        grad_tWn = torch.zeros_like(tWn)
        grad_tWp[0] = grad_outputW[posMask].sum()
        # Chain rule: the negative region outputs -Wn, so dL/dWn = -sum(dL/dq).
        # The paper writes this as +sum, which would move Wn the wrong way.
        grad_tWn[0] = -grad_outputW[negMask].sum()
        # tThresScale is a plain number, so it gets no gradient.
        return grad_fpWeights, grad_tWp, grad_tWn, None


class ternarizeWeights:
    """Callable that ternarizes conv weights and passes the image through.

    Keeps the original call pattern
    ``img, quantW = ternarizeWeights(t)(img2Conv, fpWeights, tWp, tWn)``,
    but runs on current PyTorch by delegating to a static-method autograd
    function.
    """

    def __init__(self, ternThres):
        # Threshold scale t: weights with |w| <= t * max|w| are quantized to 0.
        self.tThresScale = ternThres

    def __call__(self, img2Conv, fpWeights, tWp, tWn):
        """Return ``(img2Conv, quantW)``.

        The image is returned untouched, so its gradient flows back unchanged.
        ``quantW`` is the ternarized weight tensor with custom TTQ gradients
        for ``fpWeights``, ``tWp`` and ``tWn``.
        """
        quantW = _TernarizeWeightsFunction.apply(fpWeights, tWp, tWn,
                                                 self.tThresScale)
        return img2Conv, quantW


# write the second function (a torch.nn.Module) -> override forward:
# call the standard convolution function on the unchanged image but
# with the ternarized weights (2D case) and pass back its gradients
# -> the preceding ternarization step has to handle the gradients

class execute2DConvolution(torch.nn.Module):
    """Bias-free 2D convolution whose weight is supplied at call time.

    The stride / padding / dilation / groups settings are fixed when the
    module is built. The weight tensor is an argument of ``forward``, so
    externally produced (e.g. ternarized) weights can be convolved with the
    standard ``conv2d`` and its built-in gradients.
    """

    def __init__(self, inStride=1, inPadding=0, inDilation=1, inGroups=1):
        super(execute2DConvolution, self).__init__()
        self.cStride, self.cPad, self.cDil, self.cGrp = (
            inStride, inPadding, inDilation, inGroups)

    def forward(self, dataIn, weightIn):
        """Convolve ``dataIn`` with ``weightIn`` using the stored settings."""
        convOptions = dict(bias=None,
                           stride=self.cStride,
                           padding=self.cPad,
                           dilation=self.cDil,
                           groups=self.cGrp)
        result = torch.nn.functional.conv2d(dataIn, weightIn, **convOptions)
        return result


# wrap the two functions inside one module so that, to the outside world,
# it appears as a customized Conv1D/Conv2D convolution

class ttqConv2D(torch.nn.Module):
    """2D convolution whose weights are ternarized on every forward pass (TTQ).

    Holds a full-precision weight ``fpWeight`` of shape
    (inWCout, inWCin, inWH, inWH) and the learned scaling factors
    ``tWeightPos`` / ``tWeightNeg`` (one-element tensors). ``forward``
    ternarizes the weight with ``ternarizeWeights`` and convolves the input
    with the result. No bias is used.

    Args:
        inWCin: size of the weight's input-channel dimension.
        inWCout: number of output channels.
        inWH: side length of the square kernel.
        inStride, inPad, inDil, inGroups: passed to the 2D convolution.
        inTScale: threshold scale handed to ``ternarizeWeights``.
    """

    def __init__(self, inWCin, inWCout, inWH,
                 inStride=1, inPad=0, inDil=1, inGroups=1, inTScale=0.05):
        super(ttqConv2D, self).__init__()
        # initialize all parameters that the convolution function needs to know
        self.conStride = inStride
        self.conPad = inPad
        # outPad / conTrans are not read by this class; kept as public
        # attributes in case other code inspects them.
        self.outPad = 0
        self.conDil = inDil
        self.conTrans = False
        self.conGroups = inGroups
        # NOTE(review): with inGroups > 1, conv2d expects the weight's second
        # dimension to be in_channels / groups, so inWCin presumably has to be
        # given per group -- confirm with callers.

        # initialize the full-precision weights and the two ternary scaling
        # factors (there is no bias)
        self.tScale = inTScale
        self.fpWeight = torch.nn.Parameter(torch.Tensor(inWCout, inWCin, inWH, inWH))
        # xavier weight initialization; the in-place `_` variants replace the
        # deprecated non-underscore aliases, which emit a deprecation warning
        nnInit.xavier_normal_(self.fpWeight)
        self.tWeightPos = torch.nn.Parameter(torch.Tensor([0]))
        self.tWeightNeg = torch.nn.Parameter(torch.Tensor([0]))
        nnInit.uniform_(self.tWeightPos, 0.8, 1.2)
        nnInit.uniform_(self.tWeightNeg, 0.8, 1.2)

    def forward(self, dataInput):
        """Ternarize ``fpWeight`` and convolve ``dataInput`` with it."""
        dInput, tWeights = ternarizeWeights(self.tScale)(dataInput, self.fpWeight,
                                                         self.tWeightPos, self.tWeightNeg)

        return execute2DConvolution(self.conStride, self.conPad,
                                    self.conDil, self.conGroups)(dInput, tWeights)


# just the same functionality as in the 2D case.

class execute1DConvolution(torch.nn.Module):
    """Bias-free 1D convolution whose weight is supplied at call time.

    The 1D counterpart of the 2D executor: hyper-parameters are fixed at
    construction, and the (e.g. ternarized) weight is passed to ``forward``
    together with the data.
    """

    def __init__(self, inStride=1, inPadding=0, inDilation=1, inGroups=1):
        super(execute1DConvolution, self).__init__()
        self.cStride, self.cPad, self.cDil, self.cGrp = (
            inStride, inPadding, inDilation, inGroups)

    def forward(self, dataIn, weightIn):
        """Convolve ``dataIn`` with ``weightIn`` using the stored settings."""
        conv1d = torch.nn.functional.conv1d
        return conv1d(dataIn, weightIn,
                      bias=None,
                      stride=self.cStride,
                      padding=self.cPad,
                      dilation=self.cDil,
                      groups=self.cGrp)


class ttqConv1D(torch.nn.Module):
    """1D convolution whose weights are ternarized on every forward pass (TTQ).

    Holds a full-precision weight ``fpWeight`` of shape (inWCout, inWCin, inL)
    and the learned scaling factors ``tWeightPos`` / ``tWeightNeg``
    (one-element tensors). ``forward`` ternarizes the weight with
    ``ternarizeWeights`` and convolves the input with the result. No bias is
    used.

    Args:
        inWCin: size of the weight's input-channel dimension.
        inWCout: number of output channels.
        inL: kernel length.
        inStride, inPad, inDil, inGroups: passed to the 1D convolution.
        inTScale: threshold scale handed to ``ternarizeWeights``.
    """

    def __init__(self, inWCin, inWCout, inL,
                 inStride=1, inPad=0, inDil=1, inGroups=1, inTScale=0.05):
        super(ttqConv1D, self).__init__()
        # initialize all parameters that the convolution function needs to know
        self.conStride = inStride
        self.conPad = inPad
        # outPad / conTrans are not read by this class; kept as public
        # attributes in case other code inspects them.
        self.outPad = 0
        self.conDil = inDil
        self.conTrans = False
        self.conGroups = inGroups
        # NOTE(review): with inGroups > 1, conv1d expects the weight's second
        # dimension to be in_channels / groups, so inWCin presumably has to be
        # given per group -- confirm with callers.

        # initialize the full-precision weights and the two ternary scaling
        # factors (there is no bias)
        self.tScale = inTScale
        self.fpWeight = torch.nn.Parameter(torch.Tensor(inWCout, inWCin, inL))
        # xavier weight initialization; the in-place `_` variants replace the
        # deprecated non-underscore aliases, which emit a deprecation warning
        nnInit.xavier_normal_(self.fpWeight)
        self.tWeightPos = torch.nn.Parameter(torch.Tensor([0]))
        self.tWeightNeg = torch.nn.Parameter(torch.Tensor([0]))
        nnInit.uniform_(self.tWeightPos, 0.8, 1.2)
        nnInit.uniform_(self.tWeightNeg, 0.8, 1.2)

    def forward(self, dataInput):
        """Ternarize ``fpWeight`` and convolve ``dataInput`` with it."""
        dInput, tWeights = ternarizeWeights(self.tScale)(dataInput, self.fpWeight,
                                                         self.tWeightPos, self.tWeightNeg)
        return execute1DConvolution(self.conStride, self.conPad,
                                    self.conDil, self.conGroups)(dInput, tWeights)
6 Likes