Call backward on function inside a backpropagation step

Hi everyone!

I’m trying to build a custom module (layer) that itself uses a custom autograd function, and inside that function I would like to reuse existing functions. As a simplified example, I wrapped a Linear layer inside my function and pass its weights in as a parameter from the “surrounding” module.
I originally asked this as a follow-up question, but I think related issues are easier to find when it is posted as a stand-alone topic.

The forward pass seems to work fine, but the backward computation inside the linFct.backward() method never terminates. More precisely, after passing the gradient via
tmpLin(tmpDataVar).backward(grad_output), nothing more happens once the Variable’s backward method calls the execution_engine.

import torch


class linFct(torch.autograd.Function):
    def forward(self, fctDataIn, fctWeight):
        self.save_for_backward(fctDataIn, fctWeight)
        tmpDataVar = torch.autograd.Variable(fctDataIn)
        tmpWeightParam = torch.nn.Parameter(fctWeight)
        tmpLin = torch.nn.Linear(3, 2, bias=False)
        tmpLin.weight = tmpWeightParam
        outFct = tmpLin(tmpDataVar)
        return outFct.data

    def backward(self, grad_output):
        fctDataIn, fctWeight = self.saved_tensors
        tmpDataVar = torch.autograd.Variable(fctDataIn, requires_grad=True)
        tmpWeightParam = torch.nn.Parameter(fctWeight)
        tmpLin = torch.nn.Linear(3, 2, bias=False)
        tmpLin.weight = tmpWeightParam
        tmpLin.zero_grad()
        print(tmpDataVar.data)
        print(tmpWeightParam.data)
        print(grad_output)
        print("still here...")
        tmpLin(tmpDataVar).backward(grad_output)
        print("cannot reach this :( ")
        grad_fctDataIn = tmpDataVar.grad.data
        grad_fctWeight = tmpWeightParam.grad.data
        print(grad_fctDataIn)
        print(grad_fctWeight)
        return grad_fctDataIn, grad_fctWeight


class linLayer(torch.nn.Module):
    def __init__(self):
        super(linLayer, self).__init__()
        self.wParam = torch.nn.Parameter(torch.randn(2, 3))
        self.fct = linFct()

    def forward(self, layerIn):
        return self.fct(layerIn, self.wParam)

x = torch.autograd.Variable(torch.randn(2, 3), requires_grad=True)
fct = linLayer()
print("forward...")
y = fct(x)
fct.zero_grad()
print("backward...")
fct(x).backward(torch.randn(2, 2))
print(x.grad.data)
print(fct.wParam.grad.data)

I assume this is not how autograd is meant to be used, and that it may not be allowed to run “backward” computations while autograd is traversing the backward graph. I would be grateful for any advice on how to call backward from inside a self-implemented backward method. Thank you!
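
For readers on a recent PyTorch release, here is a minimal sketch of the same idea written with the static-method Function API. It assumes a version in which nested (re-entrant) autograd calls inside backward are supported. The class name LinFctSketch and the use of torch.autograd.grad in place of .backward() are choices made for this illustration and are not part of the original post.

```python
import torch
import torch.nn.functional as F


class LinFctSketch(torch.autograd.Function):
    """Hypothetical modern counterpart of linFct (illustration only)."""

    @staticmethod
    def forward(ctx, data_in, weight):
        ctx.save_for_backward(data_in, weight)
        return F.linear(data_in, weight)

    @staticmethod
    def backward(ctx, grad_output):
        data_in, weight = ctx.saved_tensors
        # Grad mode is normally off while backward runs, so turn it on
        # for the small local graph that is rebuilt here.
        with torch.enable_grad():
            data_leaf = data_in.detach().requires_grad_(True)
            weight_leaf = weight.detach().requires_grad_(True)
            out = F.linear(data_leaf, weight_leaf)
            # Nested autograd call: gradients of the local graph only.
            grad_data, grad_weight = torch.autograd.grad(
                out, (data_leaf, weight_leaf), grad_output)
        return grad_data, grad_weight


torch.manual_seed(0)
x = torch.randn(2, 3, requires_grad=True)
w = torch.randn(2, 3, requires_grad=True)
g = torch.randn(2, 2)
LinFctSketch.apply(x, w).backward(g)
```

For a linear map the result can be checked by hand: the input gradient should equal g @ w and the weight gradient g.t() @ x.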

In the case that this functionality is not supported now, are you planning to add it in future releases @smth @apaszke?

Are there any ideas for a workaround?
Some more background on this problem: I’m trying to implement the weight update of the “Trained Ternary Quantization” paper (https://arxiv.org/abs/1612.01064). My idea was a “wrapper” module that holds both the continuous weights and the quantized weights as parameters and then calls a convolution function with this custom backward behaviour.
Thanks again to anyone who takes a look and shares some thoughts!

Hi again,
In the meantime I opened an issue on GitHub (https://github.com/pytorch/pytorch/issues/1776), and it seems there is no easy solution for this.
However, I’m now interested in an “easy” way to just call the “backward” methods that return the computed gradients with respect to the input and the parameters. To be clear, I don’t want to reinvent the wheel and implement my own backward convolution, since this functionality is already implemented in the framework.
So, what would be the most convenient way to get functionality like

gradInput = conv_backward(Input, weight, bias, gradOutput)
and
gradWeight, gradBias = conv_backward(Input, weight, bias, gradOutput)
?
That would at least solve my problem for now.
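
As a hedged pointer for readers on newer releases: recent PyTorch versions ship the helpers torch.nn.grad.conv2d_input and torch.nn.grad.conv2d_weight, which come close to the conv_backward wished for above (they may not exist in the version used in this thread). The sketch below compares them against plain autograd. The tensor sizes are arbitrary, and the bias gradient is computed by hand as the sum of grad_output over the batch and spatial dimensions.

```python
import torch
import torch.nn.functional as F
from torch.nn import grad as nn_grad

torch.manual_seed(0)
inp = torch.randn(4, 2, 17, 17, requires_grad=True)
weight = torch.randn(3, 2, 5, 5, requires_grad=True)

out = F.conv2d(inp, weight, stride=2, padding=0)
grad_output = torch.randn_like(out)

# Gradient with respect to the input: needs the input *shape* and the weight.
grad_input = nn_grad.conv2d_input(
    inp.shape, weight.detach(), grad_output, stride=2, padding=0)
# Gradient with respect to the weight: needs the input and the weight *shape*.
grad_weight = nn_grad.conv2d_weight(
    inp.detach(), weight.shape, grad_output, stride=2, padding=0)
# A conv bias is added to every position of its channel, so its gradient is
# the sum of grad_output over batch, height and width.
grad_bias = grad_output.sum(dim=(0, 2, 3))

# Reference values from the regular autograd path.
ref_input, ref_weight = torch.autograd.grad(out, (inp, weight), grad_output)
```

With these two calls a custom backward can reuse the framework's convolution gradients and still modify grad_weight before returning it.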

After looking deeper into the code, I built a workaround for the convolution function, since I need to tune the gradients of the weights. I think it’s a little messy, so feedback on the following (working!) snippet is welcome:

import torch
from torch.autograd import Variable
import torch.nn._functions as tnnf
import time


class exploreConv(torch.autograd.Function):
    def __init__(self, inStride=1, inPad=0, inDil=1, inGroups=1, imgDim=2):
        super(exploreConv, self).__init__()
        self.conStride = ()
        self.conPad = ()
        self.conDil = ()
        self.conGroups = inGroups
        for k in range(imgDim):
            self.conStride = self.conStride + (inStride,)
            self.conPad = self.conPad + (inPad,)
            self.conDil = self.conDil + (inDil,)

        self.convFct = tnnf.conv.ConvNd(self.conStride, self.conPad, self.conDil,
                                        False, (0, 0), self.conGroups)

    def forward(self, inImg, inKernel, inBias=None):
        self.save_for_backward(inImg, inKernel, inBias)
        self.convFct.requires_grad = True
        return self.convFct.forward(inImg, inKernel, inBias)

    def backward(self, grad_output):
        inImg, inKernel, inBias = self.saved_tensors
        if inBias is not None:
            self.convFct.needs_input_grad = (True, True, True)
        else:
            self.convFct.needs_input_grad = (True, True, False)

        # for surveillance purpose -> time wait
        time.sleep(0)
        gradIn = self.convFct._grad_input(inImg, inKernel, grad_output)
        gradWeight, gradBias = self.convFct._grad_params(inImg, inKernel, inBias, grad_output)
        return gradIn, gradWeight, gradBias


xBSz, xChan, xH, xW = 50, 2, 512, 512
kOutChan, kInChan, kH, kW = 1, 2, 5, 5
pad = 0
convStride = 2
xImgDim = 2
x = torch.randn(xBSz, xChan, xH, xW)
k = torch.randn(kOutChan, kInChan, kH, kW)
# integer division keeps the sizes usable as tensor dimensions (also on Python 3)
yH = (xH - kH + 2*pad)//convStride + 1
yW = (xW - kW + 2*pad)//convStride + 1

cudaFlag = True
if cudaFlag:
    xV1 = Variable(x.cuda(), requires_grad=True)
    kV1 = Variable(k.cuda(), requires_grad=True)
    xV2 = Variable(x.cuda(), requires_grad=True)
    kV2 = Variable(k.cuda(), requires_grad=True)
    randGrad = torch.randn(xBSz, kOutChan, yH, yW).cuda()
else:
    xV1 = Variable(x, requires_grad=True)
    kV1 = Variable(k, requires_grad=True)
    xV2 = Variable(x, requires_grad=True)
    kV2 = Variable(k, requires_grad=True)
    randGrad = torch.randn(xBSz, kOutChan, yH, yW)

#print(dir(tnnf.conv.ConvNd))


y1 = torch.nn.functional.conv2d(xV1, kV1, padding=pad, stride=convStride)
y1.backward(randGrad)

m = exploreConv(inStride=convStride, inPad=pad, imgDim=xImgDim)
start_time = time.time()
#yEC = m(xV2, kV2).backward(randGrad)
m(xV2, kV2).backward(randGrad)
end_time = time.time()
print(y1.size())
print(yH, yW)
# sum absolute differences so positive and negative errors cannot cancel out
print((xV1.grad.data - xV2.grad.data).abs().sum())
print((kV1.grad.data - kV2.grad.data).abs().sum())
print("elapsed time: {} sec, GPU: {}".format(end_time - start_time, cudaFlag))
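
The yH/yW arithmetic in the snippet above follows the usual convolution output-size formula. A small standalone sketch of it is shown here; the helper name conv_output_size and the dilation term are additions for illustration, since the snippet itself uses dilation 1.

```python
def conv_output_size(in_size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a convolution along one spatial dimension."""
    # A dilated kernel covers dilation * (kernel_size - 1) + 1 input positions.
    effective_kernel = dilation * (kernel_size - 1) + 1
    # Floor division mirrors how the framework rounds down partial windows.
    return (in_size + 2 * padding - effective_kernel) // stride + 1


# Values from the snippet: 512x512 image, 5x5 kernel, stride 2, no padding.
y_h = conv_output_size(512, 5, stride=2, padding=0)
y_w = conv_output_size(512, 5, stride=2, padding=0)
print(y_h, y_w)
```

Using floor division keeps the result an integer, which matters because the sizes are passed to torch.randn as tensor dimensions.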

To everyone who faces a similar problem, i.e. manipulating gradients in the backward step: I have now found a working solution for myself by breaking the problem down into two steps.

import torch
import torch.nn.init as nnInit
from torch.autograd import Variable


# implementing the TRAINED TERNARY QUANTIZATION layer
# Zhu et al. ICLR 2017
# 1a)   split the method into two functions: the first function TERNARIZES
#       the weights; the second function just calls the standard convolution,
#       i.e. it is NOT necessary to re-implement the convolution
# 1b)   implement the forward and backward functionality as a
#       torch.autograd.Function that just manipulates the gradients that are
#       passed back from the convolution layer
# 2)    wrap these 2 functions in a torch.nn.Module


# write the first autograd function -> override forward & backward:
# pass the imagedata and its gradient through (left unchanged!) and just
# implement the weight manipulations: forward -> ternarize
# backward: do the custom gradient manipulation on weight gradients, that
# were generated by the standard convolution layer following afterwards


class ternarizeWeights(torch.autograd.Function):

    def __init__(self, ternThres):
        super(ternarizeWeights, self).__init__()
        self.tThresScale = ternThres

    def forward(self, img2Conv, fpWeights, tWp, tWn):
        self.save_for_backward(img2Conv, fpWeights, tWp, tWn)
        # compute the ternarized weights with the given threshold

        quantW = fpWeights.clone()
        quantThres = self.tThresScale * quantW.abs().max()

        # normalize weights
        quantW = quantW / quantW.abs().max()
        # quantize to {-Wn, 0, Wp}
        quantW[quantW.abs() <= quantThres] = 0.0
        quantW[quantW < -1.0*quantThres] = -1.0*tWn[0]
        quantW[quantW > quantThres] = tWp[0]
        return img2Conv, quantW

    def backward(self, grad_outputImg, grad_outputW):
        img2Conv, fpWeights, tWp, tWn = self.saved_tensors
        quantThres = self.tThresScale * fpWeights.abs().max()
        # leave the image gradient unchanged
        grad_img2Conv = grad_outputImg
        # all other gradients have to be manipulated
        # first clone them, to ensure the correct datatype and dimension
        grad_fpWeights = fpWeights.clone()
        grad_tWp = tWp.clone()
        grad_tWn = tWn.clone()

        grad_tWp[0] = grad_outputW[fpWeights > quantThres].sum()
        grad_tWn[0] = grad_outputW[fpWeights < -1.0 * quantThres].sum()

        grad_fpWeights[fpWeights > quantThres] = grad_outputW[fpWeights > quantThres] * tWp[0]
        grad_fpWeights[fpWeights < -1.0 * quantThres] = grad_outputW[fpWeights < -1.0 * quantThres] * tWn[0]
        grad_fpWeights[fpWeights.abs() <= quantThres] = 1.0 * grad_outputW[fpWeights.abs() <= quantThres]
        return grad_img2Conv, grad_fpWeights, grad_tWp, grad_tWn


# write the second autograd function -> override forward
# call the standard convolution function on the unchanged image but
# with ternarized weights (2D case) and pass back their gradients
# -> the preceding layer has to handle the gradients

class execute2DConvolution(torch.nn.Module):
    def __init__(self, inStride=1, inPadding=0, inDilation=1, inGroups=1):
        super(execute2DConvolution, self).__init__()
        self.cStride = inStride
        self.cPad = inPadding
        self.cDil = inDilation
        self.cGrp = inGroups

    def forward(self, dataIn, weightIn):
        return torch.nn.functional.conv2d(dataIn, weightIn, bias=None,
                                          stride=self.cStride, padding=self.cPad,
                                          dilation=self.cDil, groups=self.cGrp)


# wrap the two functions inside one module, so this appears as a
# customized convolution to the outside world as Conv1/2D

class ttqConv2D(torch.nn.Module):
    def __init__(self, inWCin, inWCout, inWH,
                 inStride=1, inPad=0, inDil=1, inGroups=1, inTScale=0.05):
        super(ttqConv2D, self).__init__()
        # initialize all parameters that the convolution function needs to know
        self.conStride = inStride
        self.conPad = inPad
        self.outPad = 0
        self.conDil = inDil
        self.conTrans = False
        self.conGroups = inGroups

        # initialize the threshold scale, the full-precision weights and the
        # two ternary scaling factors (no bias is used)
        self.tScale = inTScale
        self.fpWeight = torch.nn.Parameter(torch.Tensor(inWCout, inWCin, inWH, inWH))
        # xavier weight initialization
        nnInit.xavier_normal(self.fpWeight)
        self.tWeightPos = torch.nn.Parameter(torch.Tensor([0]))
        self.tWeightNeg = torch.nn.Parameter(torch.Tensor([0]))
        nnInit.uniform(self.tWeightPos, 0.8, 1.2)
        nnInit.uniform(self.tWeightNeg, 0.8, 1.2)

    def forward(self, dataInput):
        dInput, tWeights = ternarizeWeights(self.tScale)(dataInput, self.fpWeight,
                                                         self.tWeightPos, self.tWeightNeg)

        return execute2DConvolution(self.conStride, self.conPad,
                                    self.conDil, self.conGroups)(dInput, tWeights)


# just the same functionality as in the 2D case.

class execute1DConvolution(torch.nn.Module):
    def __init__(self, inStride=1, inPadding=0, inDilation=1, inGroups=1):
        super(execute1DConvolution, self).__init__()
        self.cStride = inStride
        self.cPad = inPadding
        self.cDil = inDilation
        self.cGrp = inGroups

    def forward(self, dataIn, weightIn):
        return torch.nn.functional.conv1d(dataIn, weightIn, bias=None,
                                          stride=self.cStride, padding=self.cPad,
                                          dilation=self.cDil, groups=self.cGrp)


class ttqConv1D(torch.nn.Module):
    def __init__(self, inWCin, inWCout, inL,
                 inStride=1, inPad=0, inDil=1, inGroups=1, inTScale=0.05):
        super(ttqConv1D, self).__init__()
        # initialize all parameters that the convolution function needs to know
        self.conStride = inStride
        self.conPad = inPad
        self.outPad = 0
        self.conDil = inDil
        self.conTrans = False
        self.conGroups = inGroups

        # initialize the threshold scale, the full-precision weights and the
        # two ternary scaling factors (no bias is used)
        self.tScale = inTScale
        self.fpWeight = torch.nn.Parameter(torch.Tensor(inWCout, inWCin, inL))
        # xavier weight initialization
        nnInit.xavier_normal(self.fpWeight)
        self.tWeightPos = torch.nn.Parameter(torch.Tensor([0]))
        self.tWeightNeg = torch.nn.Parameter(torch.Tensor([0]))
        nnInit.uniform(self.tWeightPos, 0.8, 1.2)
        nnInit.uniform(self.tWeightNeg, 0.8, 1.2)

    def forward(self, dataInput):
        dInput, tWeights = ternarizeWeights(self.tScale)(dataInput, self.fpWeight,
                                                         self.tWeightPos, self.tWeightNeg)
        #print(self.fpWeight)
        #print(tWeights)
        return execute1DConvolution(self.conStride, self.conPad,
                                    self.conDil, self.conGroups)(dInput, tWeights)
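
Recent PyTorch releases expect custom Functions to define static forward/backward methods and to be invoked through .apply, so the classes above need porting before they run there. The following is a rough sketch of such a port, not the author's code. The names TernarizeSketch and TTQConv2dSketch are hypothetical. The sketch drops the image pass-through and, as a simplifying assumption, thresholds the raw weights in both forward and backward instead of normalizing them first.

```python
import torch
import torch.nn.functional as F


class TernarizeSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fp_weights, w_pos, w_neg, thres_scale):
        thres = thres_scale * fp_weights.abs().max()
        pos_mask = fp_weights > thres
        neg_mask = fp_weights < -thres
        ctx.save_for_backward(pos_mask, neg_mask, w_pos, w_neg)
        # quantize to {-Wn, 0, Wp}
        zeros = torch.zeros_like(fp_weights)
        return torch.where(pos_mask, w_pos, torch.where(neg_mask, -w_neg, zeros))

    @staticmethod
    def backward(ctx, grad_quant):
        pos_mask, neg_mask, w_pos, w_neg = ctx.saved_tensors
        # scaling factors: sum of the gradients of the weights they replaced
        # (same sign convention as the snippet above)
        grad_w_pos = grad_quant[pos_mask].sum().reshape(w_pos.shape)
        grad_w_neg = grad_quant[neg_mask].sum().reshape(w_neg.shape)
        # full-precision weights: scale by Wp, Wn or 1 depending on the region
        ones = torch.ones_like(grad_quant)
        scale = torch.where(pos_mask, w_pos, torch.where(neg_mask, w_neg, ones))
        # thres_scale is a plain float, so it gets no gradient
        return grad_quant * scale, grad_w_pos, grad_w_neg, None


class TTQConv2dSketch(torch.nn.Module):
    def __init__(self, in_ch, out_ch, k, thres_scale=0.05, **conv_kwargs):
        super().__init__()
        self.thres_scale = thres_scale
        self.conv_kwargs = conv_kwargs  # e.g. stride, padding, dilation, groups
        self.fp_weight = torch.nn.Parameter(torch.empty(out_ch, in_ch, k, k))
        torch.nn.init.xavier_normal_(self.fp_weight)
        self.w_pos = torch.nn.Parameter(torch.empty(1).uniform_(0.8, 1.2))
        self.w_neg = torch.nn.Parameter(torch.empty(1).uniform_(0.8, 1.2))

    def forward(self, x):
        quant = TernarizeSketch.apply(
            self.fp_weight, self.w_pos, self.w_neg, self.thres_scale)
        return F.conv2d(x, quant, bias=None, **self.conv_kwargs)


torch.manual_seed(0)
layer = TTQConv2dSketch(1, 4, 3, padding=1)
out = layer(torch.randn(2, 1, 8, 8))
out.sum().backward()
quant_weights = TernarizeSketch.apply(
    layer.fp_weight, layer.w_pos, layer.w_neg, layer.thres_scale).detach()
```

The standard convolution still supplies the gradient with respect to the quantized weights. The custom backward only redistributes it to the full-precision weights and the two scaling factors.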

How did you pass in the weights to your module like ttqConv2D? I’m using nn.Sequential to build the model in a sequence like this:

self.convLayer = nn.Sequential(
            Conv2d(3, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=momentum, eps=eps),
            BinaryTanh(),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2))

I would like to pass the weights into BinaryTanh, which is subclassed from nn.Module like your ttqConv2D.

In this case the weights are the learnable parameters of my ttqConv modules.

I tried to mimic the torch.nn.Conv2d interface/behaviour, i.e. the weights are attributes of the layer instances. In the ttqConv2D.forward() method, I pass these weights down to the torch.nn.functional.conv2d function, which expects the weights explicitly as a function argument.
Maybe you could also write a wrapper module that holds the weights as its parameters and just calls your customized functions during the forward pass? Then you could use the wrapper module in nn.Sequential as before.
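
A minimal sketch of that wrapper idea, written for a recent PyTorch release. ScaledTanh is a hypothetical stand-in for a layer such as BinaryTanh, and its single scale parameter is invented for this illustration. The point is only that a module which owns its parameters needs no extra arguments in forward and therefore fits into nn.Sequential.

```python
import torch
import torch.nn as nn


class ScaledTanh(nn.Module):
    """Hypothetical wrapper: owns a parameter and applies a function with it."""

    def __init__(self):
        super().__init__()
        # Registered as a parameter, so optimizers and state_dict see it.
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Any custom function could be called here with self.scale.
        return torch.tanh(self.scale * x)


model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(8),
    ScaledTanh(),
    nn.MaxPool2d(2),
)

out = model(torch.randn(2, 3, 16, 16))
param_names = [name for name, _ in model.named_parameters()]
print(param_names)
```

Because nn.Sequential passes only the previous layer's output along, anything else a layer needs has to live on the module itself, as the scale parameter does here.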

Thanks for your snippet.

Why do you pass img2Conv into the ternarizeWeights function? Neither its value nor its gradient seems to be used by the function, so you might as well skip that input, right?

Yes, I think that should work too.

@magz Hi, I was looking into implementing this paper and realized that even after TTQ the model size remains the same. I understand that the model needs to be saved as 2-bit values plus scaling factors, but I’m struggling with the implementation. Do you have any suggestions, or could you share your snippet if you managed to do this?
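
One possible way to store ternary weights compactly is sketched below in plain Python. It is only an illustration of the 2-bit idea and does not come from the thread or the paper. The function names and the code assignment (-1, 0, +1 stored as 0, 1, 2) are assumptions. Four codes share one byte, and the two scaling factors are kept separately and applied when unpacking.

```python
def pack_ternary(signs):
    """Pack values from {-1, 0, 1} into bytes, four 2-bit codes per byte."""
    packed = bytearray((len(signs) + 3) // 4)
    for i, s in enumerate(signs):
        if s not in (-1, 0, 1):
            raise ValueError("expected -1, 0 or 1, got {!r}".format(s))
        # Store s + 1 (0, 1 or 2) in bits 2*(i % 4) .. 2*(i % 4) + 1.
        packed[i // 4] |= (s + 1) << (2 * (i % 4))
    return bytes(packed)


def unpack_ternary(packed, count, w_pos, w_neg):
    """Restore `count` weights as -w_neg, 0.0 or w_pos."""
    weights = []
    for i in range(count):
        code = (packed[i // 4] >> (2 * (i % 4))) & 0b11
        sign = code - 1
        if sign > 0:
            weights.append(w_pos)
        elif sign < 0:
            weights.append(-w_neg)
        else:
            weights.append(0.0)
    return weights


signs = [1, 0, -1, -1, 1]
blob = pack_ternary(signs)
restored = unpack_ternary(blob, len(signs), w_pos=0.9, w_neg=1.1)
print(len(blob), restored)
```

The element count has to be stored alongside the bytes, because the unused slots of the last byte are indistinguishable from real codes.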

Hello,

I am trying to use your ttqConv2D class in a simple test case with mnist dataset.

This is the original structure:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5, bias=False)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5, bias=False)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = F.relu(F.max_pool2d(self.conv2(x),2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

I was able to reach 99% accuracy after 10 epochs.

I then tried to replace nn.Conv2d with the following:

        super(Net, self).__init__()
        self.conv1 = ttqConv2D(1,10,5)
        self.conv2 = ttqConv2D(10,20,5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

but the model doesn’t converge in this case and gets stuck at ~11%. I then tried to modify your ttqConv2D class so that it manipulates neither the weights nor the gradients, but the model still doesn’t converge. Do you know what I did wrong?

Revised ttqConv2D without the quantization:

class ternarizeWeights(torch.autograd.Function):

    def __init__(self, ternThres):
        super(ternarizeWeights, self).__init__()
        self.tThresScale = ternThres

    def forward(self, img2Conv, fpWeights, tWp, tWn):
        self.save_for_backward(img2Conv, fpWeights, tWp, tWn)
        # compute the ternarized weights with the given threshold

        quantW = fpWeights.clone()
        quantThres = self.tThresScale * quantW.abs().max()

        return img2Conv, quantW

    def backward(self, grad_outputImg, grad_outputW):
        img2Conv, fpWeights, tWp, tWn = self.saved_tensors
        quantThres = self.tThresScale * fpWeights.abs().max()
        # leave the image gradient unchanged
        grad_img2Conv = grad_outputImg
        # all other gradients have to be manipulated
        # first clone them, to ensure the correct datatype and dimension
        grad_fpWeights = fpWeights.clone()
        grad_tWp = tWp.clone()
        grad_tWn = tWn.clone()

        return grad_img2Conv, grad_fpWeights, grad_tWp, grad_tWn


# write the second autograd function -> override forward
# call the standard convolution function on the unchanged image but
# with ternarized weights (2D case) and pass back their gradients
# -> the preceding layer has to handle the gradients

class execute2DConvolution(torch.nn.Module):
    def __init__(self, inStride=1, inPadding=0, inDilation=1, inGroups=1):
        super(execute2DConvolution, self).__init__()
        self.cStride = inStride
        self.cPad = inPadding
        self.cDil = inDilation
        self.cGrp = inGroups

    def forward(self, dataIn, weightIn):
        return torch.nn.functional.conv2d(dataIn, weightIn, bias=None,
                                          stride=self.cStride, padding=self.cPad,
                                          dilation=self.cDil, groups=self.cGrp)


# wrap the two functions inside one module, so this appears as a
# customized convolution to the outside world as Conv1/2D

class ttqConv2D(torch.nn.Module):
    def __init__(self, inWCin, inWCout, inWH,
                 inStride=1, inPad=0, inDil=1, inGroups=1, inTScale=0.05):
        super(ttqConv2D, self).__init__()
        # initialize all parameters that the convolution function needs to know
        self.conStride = inStride
        self.conPad = inPad
        self.outPad = 0
        self.conDil = inDil
        self.conTrans = False
        self.conGroups = inGroups

        # initialize the threshold scale, the full-precision weights and the
        # two ternary scaling factors (no bias is used)
        self.tScale = inTScale
        self.fpWeight = torch.nn.Parameter(torch.Tensor(inWCout, inWCin, inWH, inWH))
        # xavier weight initialization
        nnInit.xavier_normal(self.fpWeight)
        self.tWeightPos = torch.nn.Parameter(torch.Tensor([0]))
        self.tWeightNeg = torch.nn.Parameter(torch.Tensor([0]))
        nnInit.uniform(self.tWeightPos, 0.8, 1.2)
        nnInit.uniform(self.tWeightNeg, 0.8, 1.2)

    def forward(self, dataInput):
        dInput, tWeights = ternarizeWeights(self.tScale)(dataInput, self.fpWeight,
                                                         self.tWeightPos, self.tWeightNeg)

        return execute2DConvolution(self.conStride, self.conPad,
                                    self.conDil, self.conGroups)(dInput, tWeights)

Thanks,
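
One detail in the revised backward above may be worth checking: grad_fpWeights, grad_tWp and grad_tWn are clones of the weights and scaling factors themselves, and grad_outputW is never used, so the values returned as gradients are the parameter values rather than the loss gradient. A pass-through would instead hand the incoming gradient back unchanged. Below is a minimal sketch of that for a recent PyTorch release; the class name PassThroughWeights is hypothetical.

```python
import torch


class PassThroughWeights(torch.autograd.Function):
    """Hypothetical identity function: no quantization, no gradient changes."""

    @staticmethod
    def forward(ctx, fp_weights):
        return fp_weights.clone()

    @staticmethod
    def backward(ctx, grad_output_w):
        # Return the incoming gradient itself, not a copy of the weights.
        return grad_output_w


torch.manual_seed(0)
w = torch.randn(4, 3, requires_grad=True)
upstream = torch.randn(4, 3)
PassThroughWeights.apply(w).backward(upstream)
print(torch.equal(w.grad, upstream))
```

In the legacy four-input version, the corresponding change would be to return grad_outputW for the weights and zero tensors for the two unused scaling factors.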

It seems this issue has been solved in more recent releases. Check out the GitHub issue opened by the OP: https://github.com/pytorch/pytorch/issues/1776