Call self-defined function twice - autograd not working

Hi,

As I mentioned in my last post on gradient checking, I have defined a new function, SolveTrianguler:

import torch
from scipy.linalg import solve_triangular
from torch.autograd import Function

class SolveTrianguler(Function):
    # Aware of btrisolve, btrifact, and more will come
    # solves A * x = b
    def __init__(self, lower=True):
        super(SolveTrianguler, self).__init__()
        # lower=False, use data contained in the upper triangular, the default is lower
        self.lower = lower
        self.needs_input_grad = (True, False)

    def forward(self, matrix, rhs):
        x = torch.from_numpy(
            solve_triangular(matrix.numpy(), rhs.numpy(),
                             trans=0, lower=self.lower))  # trans=0: solve A * x = b
        self.save_for_backward(matrix, x)
        return x

    def backward(self, grad_output):
        # grad_matrix = grad_rhs = None
        matrix, x = self.saved_tensors
        # formula from Giles 2008, 2.3.1
        if self.lower:
            return torch.tril(-matrix.inverse().t().mm(grad_output).mm(torch.t(x))), None
        else:
            return torch.triu(-matrix.inverse().t().mm(grad_output).mm(torch.t(x))), None

When I call this function just once, the gradients after backward match the ones from TensorFlow. In forward, the torch.Tensor is converted into a NumPy array, then converted back after calling scipy.linalg.solve_triangular. All of this happens on the right-hand side of the assignment (otherwise it breaks, and I don't know why).

The problem is that when I call it twice in a row, it gives me the wrong answer.

    import numpy as np
    import tensorflow as tf
    import torch as th
    from torch.autograd import Variable

    init_K = np.cov(np.random.rand(3, 6))
    # TF
    L = tf.Variable(np.tril(init_K))
    d = tf.placeholder(tf.float64, shape=(3, 1))
    alpha = tf.matrix_triangular_solve(L, d, lower=True)
    alpha2 = tf.matrix_triangular_solve(tf.transpose(L), alpha, lower=False)

    y = tf.reduce_sum(alpha2)
    grads = tf.gradients(y, [L, alpha, alpha2])

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        grads = sess.run(grads, feed_dict={d: [[1.],[2],[3]]})
        print(grads)

    # pytorch
    L = Variable(th.from_numpy(np.tril(init_K)), requires_grad=True)
    d = Variable(th.from_numpy(np.array([[1.],[2],[3]])), requires_grad=False)
    alpha = SolveTrianguler(lower=True)(L, d)
    # alpha.register_hook(print_grad) - not working here!
    alpha2 = SolveTrianguler(lower=False)(L.t(), alpha)
    # alpha2.register_hook(print_grad)
    y = th.sum(alpha2)

    y.backward()
    print(L.grad)
    print(alpha.grad)
    print(alpha2.grad)

Calling SolveTrianguler twice gives the wrong gradients (I assume TensorFlow is correct), and alpha.grad and alpha2.grad are None.

Why is that?

Autograd Function instances are not meant to be used twice. You create a new instance for every forward + backward call.

Could you elaborate in terms of my code? What should I do if I need to call SolveTrianguler twice within one forward computation?

    # pytorch
    L = Variable(th.from_numpy(np.tril(init_K)), requires_grad=True)
    d = Variable(th.from_numpy(np.array([[1.],[2],[3]])), requires_grad=False)
    alpha = SolveTrianguler(lower=True)(L, d)
    # alpha.register_hook(print_grad) - not working here!
    alpha2 = SolveTrianguler(lower=False)(L.t(), alpha)
    # alpha2.register_hook(print_grad)
    y = th.sum(alpha2)
    y.backward()
    

Yes, you create one as many times as you invoke the function.
Each SolveTrianguler instance becomes a node in the autograd graph, so sharing one instance across calls would end up overwriting the saved_tensors of the first call, and so on.

Your snippet now looks correct.

Thanks for your reply, but the gradients from PyTorch and TensorFlow still do not agree, and alpha.grad and alpha2.grad are None. I don't know whether this has anything to do with the implementation of the function, in particular the conversion between Tensor and np.ndarray within forward. Looking forward to the update on more matrix functions…
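One possible reading of the two symptoms above, offered as a sketch rather than a confirmed diagnosis: backward returns None as the gradient for rhs, so when alpha is passed into the second solve, nothing flows back through it to the first solve; and alpha.grad is None because autograd keeps .grad only on leaf variables unless asked otherwise. The code below assumes a recent PyTorch release that provides torch.linalg.solve_triangular, the static-method Function API, and retain_grad(), none of which appear in the original post. The class name SolveTriangularSketch and the matrix values are made up for illustration.

```python
import torch
from torch.autograd import Function


class SolveTriangularSketch(Function):
    """Hypothetical rewrite: solves A x = b and returns gradients for A and b."""

    @staticmethod
    def forward(ctx, matrix, rhs, lower=True):
        # torch.linalg.solve_triangular takes `upper`, the opposite of `lower`.
        x = torch.linalg.solve_triangular(matrix, rhs, upper=not lower)
        ctx.save_for_backward(matrix, x)
        ctx.lower = lower
        return x

    @staticmethod
    def backward(ctx, grad_output):
        matrix, x = ctx.saved_tensors
        # Gradient w.r.t. b is A^{-T} g; A^T has the opposite triangle of A.
        grad_rhs = torch.linalg.solve_triangular(
            matrix.t(), grad_output, upper=ctx.lower)
        # Gradient w.r.t. A is -A^{-T} g x^T, kept only on the triangle that is read.
        grad_matrix = -grad_rhs.mm(x.t())
        grad_matrix = torch.tril(grad_matrix) if ctx.lower else torch.triu(grad_matrix)
        # One return value per forward argument; `lower` gets no gradient.
        return grad_matrix, grad_rhs, None


# Illustrative, well-conditioned lower-triangular matrix (made-up values).
L = torch.tensor([[2.0, 0.0, 0.0],
                  [1.0, 3.0, 0.0],
                  [0.5, 1.0, 4.0]], dtype=torch.float64, requires_grad=True)
d = torch.tensor([[1.0], [2.0], [3.0]], dtype=torch.float64)

alpha = SolveTriangularSketch.apply(L, d, True)
alpha.retain_grad()   # keep .grad on this intermediate result
alpha2 = SolveTriangularSketch.apply(L.t(), alpha, False)
alpha2.retain_grad()
y = alpha2.sum()
y.backward()

print(L.grad)
print(alpha.grad)
print(alpha2.grad)
```

With the gradient for rhs returned, L.grad collects contributions from both solves, which is what tf.gradients would be expected to do in the TensorFlow version.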

Thanks for the replies so far.
Will there be a way in the future to call a custom autograd Function several times, so that each call in the graph knows which saved_tensors belong to it?
This would make it easier to write custom activation functions, as in the example given here: http://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-defining-new-autograd-functions, and to reuse them several times, as is possible with F.relu etc.
Of course, the function would depend only on the input, without internal parameters.

Hello @magz,

Note that what you see as torch.nn.functional.relu (F.relu) is not a Function instance, but rather a “factory”, similar to what is done with Variables when you use somevar = Variable(...); b = somevar.someop(...).
If you look at the source code of Variable, you can see what is done internally to make Variable.someop work (which you can do manually as well).
The way the Function class then works is that you record what you need in the forward and compute the gradient in the backward at the point specified by the inputs of the forward. This used to be done in objects of the class, but it has been separated into contexts for the new-style autograd, which will allow higher-order derivatives.

Best regards

Thomas
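As a rough illustration of the context-based style described above, here is a minimal sketch that assumes a recent PyTorch release in which forward and backward are static methods and the function is invoked through apply. The name MyReLU and the input values are hypothetical. Each apply call gets its own context and its own graph node, so the same class can be used twice in one forward pass without the saved tensors colliding.

```python
import torch
from torch.autograd import Function


class MyReLU(Function):
    """Hypothetical custom activation written in the context-based style."""

    @staticmethod
    def forward(ctx, inp):
        # The context belongs to this single call, not to the class.
        ctx.save_for_backward(inp)
        return inp.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[inp < 0] = 0   # no gradient where the input was negative
        return grad_input


x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
h1 = MyReLU.apply(x)          # first call, its own context
h2 = MyReLU.apply(h1 - 1.0)   # second call, a separate context
h2.sum().backward()

# Reference computed with the built-in relu for comparison.
x_ref = x.detach().clone().requires_grad_(True)
torch.relu(torch.relu(x_ref) - 1.0).sum().backward()

print(x.grad)
print(x_ref.grad)
```

The two printed gradients should agree, since each call stored its input in a separate context.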

Thanks for the quick reply! I’ll have a look!