Multiple gradient for backward return and receive

class DTT2B(DTL2R):
    def __init__(self):
        super(DTT2B, self).__init__()

    def forward(self, weight, x):
        self.dim = x.size()
        self.intermediate = torch.zeros(self.dim)
        self.weight = weight
        h = x.clone()
        for i in range(1, self.dim[2]):
            self.intermediate[:, :, i, :] = h[:, :, i-1, :] - h[:, :, i, :]
            h[:, :, i, :] += weight[i, :]*self.intermediate[:, :, i, :]
        return h

    def backward(self, grad_weight, grad_output):
        print '2'*200
        intermediate, weight = self.intermediate, self.weight
        grad_weight, grad_input = grad_weight.clone(), grad_output.clone()
        print grad_output.size()
        for i in range(self.dim[2]-1, 0, -1):
            grad_weight[i, :] += (grad_input[:, :, i+1, :] *
                intermediate[:, :, :, i]).sum(dim=0).sum(dim=0)
            grad_input[:, :, i-1, :] = weight[i, :] * grad_input[:, :, i, :]
            grad_input[:, :, i, :] = (1-weight[i, :]) * grad_input[:, :, i, :]
        print self.__class__.__name__, grad_input, grad_weight
        return grad_input, grad_weight

class DTB2T(DTL2R):
    def __init__(self):
        super(DTB2T, self).__init__()

    def forward(self, weight, x):
        self.dim = x.size()
        self.intermediate = torch.zeros(self.dim)
        self.weight = weight
        h = x.clone()
        for i in range(self.dim[2]-2, -1, -1):
            self.intermediate[:, :, i, :] = h[:, :, i+1, :] - h[:, :, i, :]
            h[:, :, i, :] += weight[i+1, :]*self.intermediate[:, :, i, :]
        return h

    def backward(self, grad_weight, grad_output):
        print '3'*200
        intermediate, weight = self.intermediate, self.weight
        grad_weight, grad_input = grad_weight.clone(), grad_output.clone()
        for i in range(self.dim[2]-1):
            grad_weight[i+1, :] += (grad_input[:, :, i, :] *
                intermediate[:, :, i, :]).sum(dim=0).sum(dim=0)
            grad_input[:, :, i+1, :] = weight[i+1, :] * grad_input[:, :, i, :]
            grad_input[:, :, i, :] = (1-weight[i+1, :]) * grad_input[:, :, i, :]
        print self.__class__.__name__, grad_input, grad_weight
        print grad_input.size(), grad_weight.size()
        return grad_input, grad_weight
DT1 = DTT2B()
DT2 = DTB2T()
x=...
h=DT1(x)
h=DT2(h)
loss=softmax2d_loss()
loss.backward()# here would be error

Could any one can help me? I want to propagate multiple gradient to previous layer.
TypeError: backward() takes exactly 3 arguments (2 given)

Hi,

Your function takes 2 input (w and x) and has 1 output (o).
So the backward function will take as input the gradients wrt to o: do and should return gradients for all the inputs: dw and dx. Your backward function should not expect a grad_weight as input.