Question about using another model in a customized loss function (grad None error)

Hi Broken!

My apologies – I misunderstood your question.

What’s going on is that autograd doesn’t track computations inside of an
autograd.Function. It’s as if the computations are performed inside of
a with torch.no_grad(): block.
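
To make the analogy concrete, a tensor computed under no_grad() simply doesn't
get a grad_fn (a two-line illustration, not your code):

import torch

x = torch.tensor ([1.0, 2.0], requires_grad = True)
with torch.no_grad():
    y = 2.0 * x
print (y.requires_grad, y.grad_fn)   # False None -- y is cut off from the graph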

Here is a simple script that illustrates this:

import torch
print (torch.__version__)

x = torch.tensor ([1.0, 2.0], requires_grad = True)

def someFunction (x):
    print ('torch.is_grad_enabled() =', torch.is_grad_enabled())
    y = 2.0 * x
    print ('y =', y, '(inside someFunction())')   # tracked by autograd
    return  y

u = someFunction (x)
print ('u =', u)                                  # tracked by autograd

class SomeCustomFunction (torch.autograd.Function):
    @staticmethod
    def forward (ctx, x):
        z = someFunction (x)                      # no longer tracked by autograd
        return  z
    
    @staticmethod
    def backward(ctx, grad_output):
        print ('ctx.saved_tensors = ...')
        return  None

u = SomeCustomFunction.apply (x)                  # apply() is called on the class itself
print ('u =', u)                                  # use of SomeCustomFunction is tracked by autograd, but not its internals

And here is its output:

2.2.1
torch.is_grad_enabled() = True
y = tensor([2., 4.], grad_fn=<MulBackward0>) (inside someFunction())
u = tensor([2., 4.], grad_fn=<MulBackward0>)
torch.is_grad_enabled() = False
y = tensor([2., 4.]) (inside someFunction())
u = tensor([2., 4.], grad_fn=<SomeCustomFunctionBackward>)

I don’t know if there is a way to reliably reenable autograd tracking inside of a
custom Function. Is there some way you can recast your program logic so
that it doesn’t use a custom Function?
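
(For reference, if you do end up needing a custom Function, the standard pattern
is to save what backward() needs with ctx.save_for_backward() and return the
gradient yourself, rather than relying on autograd tracking inside of forward().
Here is a minimal sketch for the 2.0 * x example above -- the name
DoubleFunction is just for illustration:)

import torch

class DoubleFunction (torch.autograd.Function):
    @staticmethod
    def forward (ctx, x):
        ctx.save_for_backward (x)            # shown for the pattern (not needed for this derivative)
        return  2.0 * x                      # computed without autograd tracking

    @staticmethod
    def backward (ctx, grad_output):
        (x,) = ctx.saved_tensors
        return  2.0 * grad_output            # d(2 * x) / dx = 2, chained with grad_output

x = torch.tensor ([1.0, 2.0], requires_grad = True)
u = DoubleFunction.apply (x)
u.sum().backward()
print (x.grad)                               # tensor([2., 2.])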

Also, your My_class_loss takes output_batch, features, and targets
as arguments, but its backward() appears to be returning the gradients
with respect to the Parameters of a temporarily-instantiated LeNet5. I see
no reason that these gradients will match up with the input arguments, so I
think a backward pass through My_class_loss will fail.
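
(As a general rule, backward() has to return one gradient -- or None -- for each
argument of forward(), in the same order, and each gradient has to have the
shape of the corresponding argument. A minimal sketch, using made-up names and
a plain mean-squared-error just to show the structure, not your actual loss:)

import torch

class MyLoss (torch.autograd.Function):
    @staticmethod
    def forward (ctx, output_batch, features, targets):
        ctx.save_for_backward (output_batch, targets)
        return  ((output_batch - targets) ** 2).mean()

    @staticmethod
    def backward (ctx, grad_output):
        output_batch, targets = ctx.saved_tensors
        grad_wrt_output = grad_output * 2.0 * (output_batch - targets) / output_batch.numel()
        # one entry per forward() argument, in order: output_batch, features, targets
        return  grad_wrt_output, None, None

output_batch = torch.randn (4, 10, requires_grad = True)
features = torch.randn (4, 16)
targets = torch.randn (4, 10)
loss = MyLoss.apply (output_batch, features, targets)
loss.backward()
print (output_batch.grad.shape)              # torch.Size([4, 10])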

Last, depending on what you actually want to compute the gradients
with respect to, you may want to backpropagate through the
optimizer1.step() call (which seems to be missing from your “final
version”).
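
(The usual way to do that is to compute the gradients with create_graph = True
and write the parameter update as an ordinary differentiable expression, rather
than calling an in-place optimizer step. A toy sketch of the idea -- not your
code, and only a guess at what you're after:)

import torch

w = torch.tensor ([1.0, 2.0], requires_grad = True)
lr = 0.1

inner_loss = (w ** 2).sum()

# create_graph = True keeps the gradient itself in the autograd graph
(g,) = torch.autograd.grad (inner_loss, w, create_graph = True)
w_updated = w - lr * g                       # the "optimizer step" as a differentiable expression

outer_loss = (w_updated ** 2).sum()
outer_loss.backward()                        # backpropagates through the update
print (w.grad)                               # tensor([1.2800, 2.5600])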

If so, some of the discussion in the following thread might be relevant:

Best.

K. Frank