Call backward on function inside a backpropagation step

After looking deeper into the code, I built a workaround for the convolution function, since I need to tune the gradients of the weights. I think it's a little messy; maybe someone wants to share some thoughts on the following (working!) snippet:

import torch
from torch.autograd import Variable
import torch.nn._functions as tnnf
import time


class exploreConv(torch.autograd.Function):
    def __init__(self, inStride=1, inPad=0, inDil=1, inGroups=1, imgDim=2):
        super(exploreConv, self).__init__()
        self.conStride = ()
        self.conPad = ()
        self.conDil = ()
        self.conGroups = inGroups
        for k in range(imgDim):
            self.conStride = self.conStride + (inStride,)
            self.conPad = self.conPad + (inPad,)
            self.conDil = self.conDil + (inDil,)

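        # the underlying ConvNd autograd function, reused for forward and for the gradient helpers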
        self.convFct = tnnf.conv.ConvNd(self.conStride, self.conPad, self.conDil,
                                        False, (0, 0), self.conGroups)

    def forward(self, inImg, inKernel, inBias=None):
        self.save_for_backward(inImg, inKernel, inBias)
        self.convFct.requires_grad = True
        return self.convFct.forward(inImg, inKernel, inBias)

    def backward(self, grad_output):
        inImg, inKernel, inBias = self.saved_tensors
        if inBias is not None:
            self.convFct.needs_input_grad = (True, True, True)
        else:
            self.convFct.needs_input_grad = (True, True, False)

        # placeholder delay for monitoring purposes
        time.sleep(0)
        gradIn = self.convFct._grad_input(inImg, inKernel, grad_output)
        gradWeight, gradBias = self.convFct._grad_params(inImg, inKernel, inBias, grad_output)
        return gradIn, gradWeight, gradBias


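# test setup: random input and kernel, plus the expected output size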
xBSz, xChan, xH, xW = 50, 2, 512, 512
kOutChan, kInChan, kH, kW = 1, 2, 5, 5
pad = 0
convStride = 2
xImgDim = 2
x = torch.randn(xBSz, xChan, xH, xW)
k = torch.randn(kOutChan, kInChan, kH, kW)
yH = (xH - kH + 2*pad) // convStride + 1   # output height (integer division)
yW = (xW - kW + 2*pad) // convStride + 1   # output width

cudaFlag = True
if cudaFlag:
    xV1 = Variable(x.cuda(), requires_grad=True)
    kV1 = Variable(k.cuda(), requires_grad=True)
    xV2 = Variable(x.cuda(), requires_grad=True)
    kV2 = Variable(k.cuda(), requires_grad=True)
    randGrad = torch.randn(xBSz, kOutChan, yH, yW).cuda()
else:
    xV1 = Variable(x, requires_grad=True)
    kV1 = Variable(k, requires_grad=True)
    xV2 = Variable(x, requires_grad=True)
    kV2 = Variable(k, requires_grad=True)
    randGrad = torch.randn(xBSz, kOutChan, yH, yW)

#print(dir(tnnf.conv.ConvNd))


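# reference path: built-in conv2d forward and backward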
y1 = torch.nn.functional.conv2d(xV1, kV1, padding=pad, stride=convStride)
y1.backward(randGrad)

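# custom path: the same convolution through exploreConv, timed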
m = exploreConv(inStride=convStride, inPad=pad, imgDim=xImgDim)
start_time = time.time()
#yEC = m(xV2, kV2).backward(randGrad)
m(xV2, kV2).backward(randGrad)
end_time = time.time()
print(y1.size())
print(yH, yW)
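# the gradient differences between the two paths should be (close to) zero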
print((xV1.grad.data - xV2.grad.data).sum())
print((kV1.grad.data - kV2.grad.data).sum())
print("elapsed time: {}".format(end_time-start_time)+" sek, GPU: {}".format(cudaFlag))