Using autograd inside a custom autograd Function


I am trying to define custom modules (Linear, Convolutional, etc.) where I would like to apply various regularizations to the parameters. The snippet below clarifies my question.
I don't want to calculate the gradients of those regularization losses manually, so I am using autograd. The problem is that although the parameters of the local net I define inside the backward function have requires_grad=True, its output does not have a grad_fn. So it seems to me that autograd doesn't work inside a custom autograd Function. How am I supposed to use autograd in this case?

Thank you,

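import torch
from torch import autograd, nn


class LinearCustomFunc(autograd.Function):
    @staticmethod
    def forward(context, input, weight, weight_cu, bias=None):
        context.save_for_backward(input, weight, weight_cu, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(context, grad_output):
        input, weight, weight_cu, bias = context.saved_tensors
        grad_input = grad_weight = grad_weight_cu = grad_bias = None

        # Standard linear-layer gradients (as in the PyTorch docs' LinearFunction example)
        if context.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if context.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if context.needs_input_grad[2]:
            # Regularization gradient for weight_cu, computed with a local net
            local_net = nn.Sequential(nn.Linear(2, 3))
            output = local_net(input)  # output.grad_fn is None here
            # The target is a placeholder; the regularizer's real target is not specified
            loss = nn.MSELoss()(output, torch.zeros_like(output))
            loss.backward()  # raises a RuntimeError because output has no grad_fn
            grads = {name: p.grad for name, p in local_net.named_parameters()}
            grad_weight_cu = grads['0.weight']
        if bias is not None and context.needs_input_grad[3]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_weight_cu, grad_bias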
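A minimal sketch that reproduces the symptom in isolation. The `Probe` class is hypothetical, written only for illustration: it is an identity Function whose backward builds a small nn.Linear, then records whether grad mode is enabled and whether the layer's output got a grad_fn. It compares a plain backward pass with one run with create_graph=True.

```python
import torch
from torch import autograd, nn


class Probe(autograd.Function):
    """Hypothetical identity Function that records what autograd sees inside backward."""

    seen = {}

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        layer = nn.Linear(2, 3)  # its parameters have requires_grad=True
        out = layer(torch.randn(4, 2))
        Probe.seen["grad_enabled"] = torch.is_grad_enabled()
        Probe.seen["has_grad_fn"] = out.grad_fn is not None
        return grad_output  # identity: pass the gradient straight through


x = torch.randn(4, 2, requires_grad=True)

autograd.grad(Probe.apply(x).sum(), x)
print(Probe.seen)  # {'grad_enabled': False, 'has_grad_fn': False}

autograd.grad(Probe.apply(x).sum(), x, create_graph=True)
print(Probe.seen)  # {'grad_enabled': True, 'has_grad_fn': True}
```

The parameters still have requires_grad=True in both runs. What changes is the grad mode the engine sets while running backward, and that mode follows create_graph.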


If you're not running backward with create_graph=True, your custom backward runs with autograd (grad mode) disabled.
If you need autograd there, you can enable it locally by wrapping the code in with torch.enable_grad():. I would also advise you to use autograd.grad here to get the gradient more easily:

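            local_net = nn.Sequential(nn.Linear(2, 3))  # This will contain random weights. Is it expected?
            with torch.enable_grad():
                output = local_net(input)
                # MSELoss() only builds the criterion; call it to get a loss value.
                # zeros is a placeholder target: use whatever your regularizer compares against.
                loss = nn.MSELoss()(output, torch.zeros_like(output))
            grad_weight_cu = autograd.grad(loss, local_net[0].weight)[0]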
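Putting the pieces together, here is a self-contained sketch of the question's function with the regularization gradient computed by autograd inside backward. Several parts are assumptions rather than anything from the thread. The `regularizer` function is a hypothetical stand-in (an MSE of weight_cu's linear output against zeros). The sketch also differentiates with respect to a detached copy of weight_cu instead of a fresh, randomly initialized nn.Linear, which sidesteps the random-weights issue raised in the answer. The forward pass and the input, weight, and bias gradients follow the LinearFunction example from the PyTorch "Extending torch.autograd" docs.

```python
import torch
import torch.nn.functional as F
from torch import autograd


def regularizer(w, x):
    """Hypothetical regularization loss: MSE of x @ w.T against zeros."""
    out = F.linear(x, w)
    return F.mse_loss(out, torch.zeros_like(out))


class LinearCustomFunc(autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, weight_cu, bias=None):
        ctx.save_for_backward(input, weight, weight_cu, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output = output + bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, weight_cu, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_weight_cu = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if ctx.needs_input_grad[2]:
            # Grad mode is off here, so turn it back on for the local computation
            with torch.enable_grad():
                w = weight_cu.detach().requires_grad_()  # fresh leaf to differentiate against
                loss = regularizer(w, input.detach())
                (grad_weight_cu,) = autograd.grad(loss, w)
        if bias is not None and ctx.needs_input_grad[3]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_weight_cu, grad_bias
```

Call it as LinearCustomFunc.apply(input, weight, weight_cu, bias). In this sketch weight_cu does not affect the forward output, so its gradient comes only from the regularizer. Because of the detach, no higher-order graph is built for that term when backward runs with create_graph=True.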
