Writing custom autograd.Function

I am trying to write a custom autograd.Function that solves logistic regression using LBFGS. While solve_logistic_regression(Xv,yv,.1) works flawlessly, lr(Xv,yv,.1) throws the following runtime error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I’m not sure what is going on here.

The code:

import numpy as np

from torch.autograd import Function, Variable
import torch

N, n = 100, 2
X = np.random.randn(N, n)
y = np.random.randint(0, 2, size=N)
Xv = Variable(torch.Tensor(X), requires_grad=False)
yv = Variable(torch.Tensor(y), requires_grad=False)

def solve_logistic_regression(X, y, lamb):
    N, n = X.shape
    theta = Variable(torch.ones(n), requires_grad=True)
    optimizer = torch.optim.LBFGS([theta], lr=.8)
    def closure():
        # LBFGS re-evaluates the loss and its gradient on every call
        optimizer.zero_grad()
        pi = 1./(1.+torch.exp(-X.mm(theta.unsqueeze(-1))))
        loss = 1./N*torch.nn.BCELoss()(pi.squeeze(), y) + lamb/2*torch.norm(theta[:-1])**2
        print(loss.item())
        loss.backward()
        return loss
    optimizer.step(closure)
    return theta

class LogisticRegression(Function):
    @staticmethod
    def forward(ctx, X, y, lamb):
        theta = solve_logistic_regression(X, y, lamb)
        return theta

    @staticmethod
    def backward(ctx, grad_output):
        # no gradient with respect to X, y or lamb
        return None, None, None

lr = LogisticRegression.apply

I have an identical issue: a function which uses autograd works when called directly, but throws the “element 0 of tensors…” error when called from the forward pass of a custom autograd.Function. The code worked in version 0.3 but no longer works in 0.4. Is there any update on this?

I'm finding that my custom Function also sets the output’s requires_grad to False. Did you solve it?
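One likely explanation for the requires_grad=False output: autograd decides whether the output of Function.apply requires grad from its inputs. If none of the tensor inputs require grad (as with Xv and yv above, created with requires_grad=False, plus the plain float lamb), the output will not require grad either. A minimal sketch with a hypothetical Double Function (y = 2 * x) that is only there to illustrate this:

```python
import torch
from torch.autograd import Function


class Double(Function):
    # Hypothetical example Function computing y = 2 * x.
    @staticmethod
    def forward(ctx, x):
        return 2 * x

    @staticmethod
    def backward(ctx, grad_output):
        # dy/dx = 2
        return 2 * grad_output


a = torch.ones(3)                      # input that does not require grad
b = torch.ones(3, requires_grad=True)  # input that does require grad

out_a = Double.apply(a)
out_b = Double.apply(b)
print(out_a.requires_grad, out_b.requires_grad)  # False True

out_b.sum().backward()
print(b.grad)  # tensor([2., 2., 2.])
```

So an output without requires_grad is expected when no input needs a gradient; it is not something forward itself controls by returning a tensor that requires grad.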

For the original poster’s question of how to use backward within a function, the solution is to wrap the calculation in a with torch.enable_grad(): block, as in

    with torch.enable_grad():
        theta = solve_logistic_regression(X, y, lamb)

(and, of course, return something reasonable).
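A minimal runnable sketch of the whole fix, assuming a recent PyTorch (plain tensors instead of Variable, @staticmethod forward/backward). It returns theta.detach() as the "something reasonable", and because backward returns None for every input, this Function is not differentiable with respect to X, y or lamb. It swaps in torch.sigmoid and torch.nn.functional.binary_cross_entropy for the hand-written sigmoid and the BCELoss module, keeping the same loss as in the question; the random data and seed are arbitrary.

```python
import torch
from torch.autograd import Function


def solve_logistic_regression(X, y, lamb):
    # Fit theta by minimizing the regularized logistic loss with LBFGS.
    N, n = X.shape
    theta = torch.ones(n, requires_grad=True)
    optimizer = torch.optim.LBFGS([theta], lr=0.8)

    def closure():
        optimizer.zero_grad()
        pi = torch.sigmoid(X.mm(theta.unsqueeze(-1))).squeeze(-1)
        loss = (torch.nn.functional.binary_cross_entropy(pi, y) / N
                + lamb / 2 * theta[:-1].norm() ** 2)
        loss.backward()
        return loss

    optimizer.step(closure)
    return theta


class LogisticRegression(Function):
    @staticmethod
    def forward(ctx, X, y, lamb):
        # forward runs with grad mode disabled; re-enable it locally so the
        # loss inside the LBFGS closure records a graph for backward().
        with torch.enable_grad():
            theta = solve_logistic_regression(X, y, lamb)
        # Return a plain tensor detached from the inner optimization graph.
        return theta.detach()

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is propagated to X, y or lamb.
        return None, None, None


lr = LogisticRegression.apply

torch.manual_seed(0)
N, n = 100, 2
X = torch.randn(N, n)
y = torch.randint(0, 2, (N,)).float()
theta = lr(X, y, 0.1)
print(theta)
```

If gradients of theta with respect to the data are actually needed, backward would have to compute them (for example via the implicit function theorem); returning None simply treats the solver as a black box.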
You can think of the forward of an autograd.Function as running inside a with torch.no_grad(): block, so you need to re-enable gradient tracking.
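A small sketch that makes the no_grad behaviour of forward visible. The Probe Function and the module-level observed dict are hypothetical, used only for inspection: Probe records torch.is_grad_enabled() inside forward, shows that backward on a loss computed there raises the same RuntimeError as in the question, and shows that the same computation works inside torch.enable_grad().

```python
import torch
from torch.autograd import Function

observed = {}  # hypothetical dict used only to inspect what happens in forward


class Probe(Function):
    @staticmethod
    def forward(ctx, x):
        observed["grad_enabled"] = torch.is_grad_enabled()

        w = torch.ones_like(x, requires_grad=True)
        loss = (w * x).sum()  # computed with grad mode off: no graph recorded
        observed["loss_has_graph"] = loss.requires_grad
        try:
            loss.backward()
            observed["error"] = None
        except RuntimeError as err:
            observed["error"] = str(err)

        # Same computation with tracking re-enabled: backward now works.
        with torch.enable_grad():
            loss = (w * x).sum()
            loss.backward()
        observed["w_grad"] = w.grad.clone()
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


x = torch.tensor([1.0, 2.0, 3.0])
Probe.apply(x)
print(observed["grad_enabled"])  # False
print(observed["error"])
print(observed["w_grad"])  # tensor([1., 2., 3.])
```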

Best regards


P.S.: As a side note, when posting your question in multiple venues (which you should do only in very rare cases), it is a great service to those answering and following along if you add references to your cross-posts. I thought I had seen the exact question before, and no, it’s not a bug.