Question about using another model in a customized loss function (grad None error)

Hi, I would like to use a model A to predict another model B's (a LeNet's) parameters, so I need to implement a loss function myself:

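(flatten_parameters and unflatten_parameters are helpers that move LeNet's weights to and from one flat vector; the snippet below is only a rough sketch of what they do, not my exact code.)

import torch

def flatten_parameters(model):
    # concatenate every parameter of the model into a single flat vector
    return torch.cat([p.flatten() for p in model.parameters()])

def unflatten_parameters(flat, model):
    # copy slices of the flat vector back into the model's parameters in place
    offset = 0
    with torch.no_grad():
        for p in model.parameters():
            n = p.numel()
            p.copy_(flat[offset:offset + n].view_as(p))
            offset += n
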
def custom_loss_function_new(output_batch, features, targets):
    batch_loss = 0.0
    features = features.to(DEVICE)
    targets = targets.to(DEVICE)

    for output in output_batch:
        output1 = output.clone()
        test_model = LeNet5(NUM_CLASSES, GRAYSCALE).to(DEVICE)
        optimizer1 = torch.optim.Adam(test_model.parameters(), lr=LEARNING_RATE1)  
        
        unflatten_parameters(output, test_model) # override LeNet's parameters
        test_model.train()

        logits, probas = test_model(features)
        cost1 = F.cross_entropy(logits, targets)
        cost1 = cost1.requires_grad_()
        print("COST1:", cost1)

        optimizer1.zero_grad()
        cost1.backward()

        for name, param in test_model.named_parameters():
            if param.requires_grad:
                print(name, param.grad)

        optimizer1.step()
        loss = F.mse_loss(output1, flatten_parameters(test_model))
        # print("loss1:",loss)
        batch_loss += loss
    average_loss = batch_loss / len(output_batch)
    return average_loss

Above is my loss function and it works well. However, I want to define my own backward function in autograd, so I defined a torch.autograd.Function class:

class My_class_loss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, output_batch, features, targets):
        average_loss = custom_loss_function_new(output_batch, features, targets)

        return average_loss
    
    @staticmethod
    def backward(ctx, grad_output):
        grads, = ctx.saved_tensors
        return grads, None, None

and use the code below to calculate the loss:

        cost = custom_loss_function_new(generated_data, features, targets) #1
        cost = My_class_loss.apply(generated_data, features, targets) #2
        

When I use custom_loss_function_new, it prints the correct param.grad,
but when I use My_class_loss, I just get None for the grads:

COST1: tensor(3.3397, requires_grad=True)
features.0.weight None
features.0.bias None
features.3.weight None
features.3.bias None
classifier.0.weight None
classifier.0.bias None
classifier.2.weight None
classifier.2.bias None
classifier.4.weight None
classifier.4.bias None

I have no idea why I get different results from the same function.

Hi Broken!

Your forward() function doesn’t call ctx.save_for_backward (...),
so in your backward() function ctx.saved_tensors returns nothing
(that is, an empty tuple). Therefore your grads come out as None.

Quoting from some of pytorch’s autograd documentation:

Step 2: It is your responsibility to use the functions in ctx properly in order to ensure that the new Function works properly with the autograd engine.
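
As a minimal sketch of the pattern (a toy squaring Function, nothing to do with your model), forward() has to stash whatever backward() will later read from ctx.saved_tensors:

import torch

class Square (torch.autograd.Function):
    @staticmethod
    def forward (ctx, x):
        ctx.save_for_backward (x)        # stash what backward() will need
        return  x * x

    @staticmethod
    def backward (ctx, grad_output):
        x, = ctx.saved_tensors           # non-empty only because forward() saved it
        return  2.0 * x * grad_output    # d(x**2)/dx, chained with the incoming gradient

x = torch.tensor ([1.0, 3.0], requires_grad = True)
Square.apply (x).sum().backward()
print (x.grad)                           # tensor([2., 6.])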

Best.

K. Frank

Thanks for your answer!
I understand that I don’t have a complete implementation of the backward function in My_class_loss, which would mean I don’t return the grad values correctly when I call this class to calculate the loss.
But my problem is that I’m calling LeNet for inference inside the forward function and LeNet is not working correctly there (LeNet’s grads are None), and I don’t think backward can affect LeNet.
The strange thing for me is that if I call custom_loss_function_new directly, LeNet works fine, but if I call custom_loss_function_new inside My_class_loss, it goes wrong, so I suspect this has something to do with how torch.autograd.Function or @staticmethod is implemented.

Hi KFrank,
Here is the final version of My_class_loss for your reference.


class My_class_loss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, output_batch, features, targets):
        batch_loss = 0.0
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)
        
        grads = torch.empty((len(output_batch), len(output_batch[0])), device=DEVICE)

        for i, output in enumerate(output_batch):
            output1 = output.clone()
            test_model = LeNet5(NUM_CLASSES, GRAYSCALE).to(DEVICE)
            optimizer1 = torch.optim.Adam(test_model.parameters(), lr=LEARNING_RATE)  
            test_model.train()
            unflatten_parameters(output1, test_model)

            logits, probas = test_model(features)
            cost1 = F.cross_entropy(logits, targets)
            cost1 = cost1.requires_grad_()

            optimizer1.zero_grad()
            cost1.backward()
            batch_loss += cost1

            for name, param in test_model.named_parameters():
                if param.requires_grad:
                    print(name, param.grad)

            grads[i] = torch.cat([p.grad.flatten() for p in test_model.parameters()])


        average_loss = batch_loss / len(output_batch)

        ctx.save_for_backward(grads)

        # print(output_batch)
        # print(features)
        # print(targets)

        # average_loss = custom_loss_function_new(output_batch, features, targets)

        return average_loss
    
    @staticmethod
    def backward(ctx, grad_output):
        grads, = ctx.saved_tensors
        return grads, None, None

So the problem here is that I get None for test_model's (LeNet's) grads.
The code in forward is just copied from custom_loss_function_new,
and as I said, if I call custom_loss_function_new directly, everything is fine and I get the correct grads.

And below is my training code, just in case you need to take a look:

for epoch in range(NUM_EPOCHS):
    
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        features, targets = features.to(DEVICE), targets.to(DEVICE)
        
        z = torch.rand((BATCH_SIZE, NUM_INPUT)).to(DEVICE)
        
        ### FORWARD AND BACK PROP
        generated_data = model(z)
        
        cost = My_class_loss.apply(generated_data, features, targets)
        # cost = custom_loss_function_new(generated_data, features, targets)

        optimizer.zero_grad()
        cost.backward()

        optimizer.step()


        ### LOGGING
        # if not batch_idx % 50:
        #     print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' 
        #           % (epoch+1, NUM_EPOCHS, batch_idx, len(train_loader), cost))
        print('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' 
                  % (epoch+1, NUM_EPOCHS, batch_idx, len(train_loader), cost))
        


        model.eval()
        with torch.set_grad_enabled(False): # save memory during inference
            z = torch.rand(1, NUM_INPUT).to(DEVICE)
            samples = model(z)

            test_model = LeNet5(NUM_CLASSES, GRAYSCALE).to(DEVICE)

            samples = samples.squeeze()
            unflatten_parameters(samples, test_model)

            print("accuracy:", compute_accuracy(test_model, test_loader).item())
            
        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

Hi Broken!

My apologies – I misunderstood your question.

What’s going on is that autograd doesn’t track computations inside of an
autograd.Function. It’s as if the computations are performed inside of
a with torch.no_grad(): block.

Here is a simple script that illustrates this:

import torch
print (torch.__version__)

x = torch.tensor ([1.0, 2.0], requires_grad = True)

def someFunction (x):
    print ('torch.is_grad_enabled() =', torch.is_grad_enabled())
    y = 2.0 * x
    print ('y =', y, '(inside someFunction())')   # tracked by autograd
    return  y

u = someFunction (x)
print ('u =', u)                                  # tracked by autograd

class SomeCustomFunction (torch.autograd.Function):
    @staticmethod
    def forward (ctx, x):
        z = someFunction (x)                      # no longer tracked by autograd
        return  z
    
    @staticmethod
    def backward(ctx, grad_output):
        print ('ctx.saved_tensors = ...')
        return  None

custom_fn = SomeCustomFunction()
u = custom_fn.apply (x)
print ('u =', u)                                  # use of SomeCustomFunction is tracked by autograd, but not its internals

And here is its output:

2.2.1
torch.is_grad_enabled() = True
y = tensor([2., 4.], grad_fn=<MulBackward0>) (inside someFunction())
u = tensor([2., 4.], grad_fn=<MulBackward0>)
torch.is_grad_enabled() = False
y = tensor([2., 4.]) (inside someFunction())
u = tensor([2., 4.], grad_fn=<SomeCustomFunctionBackward>)

I don’t know if there is a way to reliably reenable autograd tracking inside of a
custom Function. Is there some way you can recast your program logic so
that it doesn’t use a custom Function?

Also, your My_class_loss takes output_batch, features, and targets
as arguments, but its backward() appears to be returning the gradients
with respect to the Parameters of a temporarily-instantiated LeNet5. I see
no reason that these gradients will match up with the input arguments, so I
think a backward pass through My_class_loss will fail.
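
To make that contract concrete (a toy mean-squared-error written as a custom Function, not your use case), backward() must return one gradient per input of forward(), and each must be the gradient with respect to that input, with the same shape:

import torch

class MyMSE (torch.autograd.Function):
    @staticmethod
    def forward (ctx, pred, target):
        diff = pred - target
        ctx.save_for_backward (diff)
        return  (diff * diff).mean()

    @staticmethod
    def backward (ctx, grad_output):
        diff, = ctx.saved_tensors
        grad_pred = grad_output * 2.0 * diff / diff.numel()   # same shape as pred
        return  grad_pred, None          # one entry per forward() input; no grad for target

pred = torch.randn (4, requires_grad = True)
target = torch.randn (4)
MyMSE.apply (pred, target).backward()
print (pred.grad.shape)                  # torch.Size([4]) -- matches pred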

Last, depending on what you actually want to compute the gradients
with respect to, you might be wanting to backpropagate through the
optimizer1.step() call (that seems to be missing from your “final
version”).

If so, some of the discussion in the following thread might be relevant:

Best.

K. Frank

Hi KFrank
Thanks a lot for your help, I successfully solved the problem: just adding with torch.enable_grad(): inside the forward function fixed it. And that’s all.
:partying_face:
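
For anyone who finds this later, here is a toy sketch of the pattern (simplified, nothing like my real model): with torch.enable_grad() inside forward(), the inner backward() call populates the inner parameters' grads again, just like it does when the same code runs outside a custom Function.

import torch

class FixedCustomLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.enable_grad():                  # the one-line fix
            w = torch.ones(2, requires_grad=True)  # stands in for LeNet's parameters
            inner_loss = (w * x.detach()).sum()
            inner_loss.backward()                  # w.grad is no longer None
            print('w.grad =', w.grad)
        ctx.save_for_backward(w.detach())
        return inner_loss.detach()

    @staticmethod
    def backward(ctx, grad_output):
        w, = ctx.saved_tensors
        return grad_output * w                     # gradient w.r.t. x, same shape as x

x = torch.tensor([1.0, 2.0], requires_grad=True)
FixedCustomLoss.apply(x).backward()
print('x.grad =', x.grad)                          # tensor([1., 1.])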