Training a neural network and using it as a loss function

I am training a model and also want to use it as a loss function. See the schematic below:

pred = self.forward(im1, im2)

# calculate loss
with torch.no_grad():
    pred_new = self.forward(im3, pred)
loss = F.mse_loss(pred, pred_new)

When I call loss.backward(), I get a RuntimeError because the grad_fn is missing.

What would be the ideal way to calculate the loss in this scenario?

Your use case should generally work, as the target usually doesn't require a gradient, as seen here:

model = nn.Linear(10, 2)

x = torch.randn(1, 10)
pred = model(x)

# the target is created without gradient tracking
with torch.no_grad():
    pred_new = model(torch.randn(1, 10))
    
loss = F.mse_loss(pred, pred_new)
loss.backward()
print(model.weight.grad)

That being said, I don't know how your model is defined, etc., or which part of the code is raising the error.

The difference from the example you show is that in my second call to the model, the prediction (model output) from the first call is used as an input without calling detach().

Could you post a minimal, executable code snippet which would reproduce the issue, please?

This is exactly what I am doing and trying to achieve. You should be able to reproduce the RuntimeError with the snippet below:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.inp1 = nn.Linear(10, 2)
        self.inp2 = nn.Linear(10, 2)
        self.out = nn.Linear(4, 10)

    def forward(self, im1, im2):
        im1 = self.inp1(im1)
        im2 = self.inp2(im2)
        return self.out(torch.cat([im1, im2], dim=1))

model = Model()
inp1 = torch.randn(1, 10)
inp2 = torch.randn(1, 10)
gt = torch.randn(1, 10)

pred = model(inp1, inp2)

with torch.no_grad():
    pred_new = model(pred, inp1)
    
loss = F.mse_loss(gt, pred_new)
loss.backward()
print(model.out.weight.grad)  # not reached: backward() above already raises

The error is:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Thanks for the code! The error is expected as none of the tensors is attached to a computation graph:

loss = F.mse_loss(gt, pred_new)

gt is the target tensor created directly via gt = torch.randn(1, 10), while pred_new is created in the no_grad() context and is therefore also not attached to a computation graph.
Could you explain which gradients the backward() call should calculate?
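
For reference, a minimal sketch (continuing the repro snippet above) that makes the missing graph connection visible:

pred = model(inp1, inp2)
print(pred.grad_fn)        # e.g. <AddmmBackward0 ...>: attached to the graph

with torch.no_grad():
    pred_new = model(pred, inp1)

print(pred_new.grad_fn)    # None: created inside no_grad()
print(gt.requires_grad)    # False: plain randn tensor

loss = F.mse_loss(gt, pred_new)
print(loss.requires_grad)  # False, so loss.backward() raises the RuntimeError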

I have grad_fn for the operations from inp1 and inp2 to the output pred. I want the gradients to be computed for that part of the code, but with respect to the loss calculated with mse_loss, so that the parameters of the model can be updated with the optimizer's step() function.
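
In other words, the gradients should come from the first forward pass only. Simply removing the no_grad() context would make backward() run, but the second pass would then also push gradients into the same parameters. A minimal sketch of that behaviour, continuing the repro snippet above:

pred = model(inp1, inp2)
pred_new = model(pred, inp1)   # no no_grad(): the second pass stays in the graph
loss = F.mse_loss(gt, pred_new)
loss.backward()

# The gradients now contain contributions from *both* forward passes,
# since the second call reuses the same parameters.
print(model.inp1.weight.grad)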

I feel one way I could achieve this is by writing the code below. Could you verify the effect of copy.deepcopy here? What happens when the model is loaded on the GPU?

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
import copy

class Model(nn.Module):
    def __init__(self, grayscale=False):
        super().__init__()

        self.inp1 = nn.Linear(10, 2)
        self.inp2 = nn.Linear(10, 2)
        self.out = nn.Linear(4, 10)

    def forward(self, im1, im2):
        im1 = self.inp1(im1)
        im2 = self.inp2(im2)
        return self.out(torch.cat([im1, im2], dim=1))

model = Model()
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
inp1 = torch.randn(1, 10)
inp2 = torch.randn(1, 10)
gt = torch.randn(1, 10)

pred = model(inp1, inp2)
# copy the model to use it as the loss network
loss_model = copy.deepcopy(model)
pred_new = loss_model(pred, inp1)
loss = F.mse_loss(inp1, pred_new)
loss.backward()
optimizer.step()

Also, the memory footprint of the training is very high due to the copy of the model with all its parameters. Is there an elegant way to avoid copying the model while achieving the same result?
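
For what it's worth: copy.deepcopy of a model that sits on the GPU creates new parameter tensors on the same device, so the GPU memory used by the parameters roughly doubles. One possible way to avoid the copy entirely (a sketch, assuming PyTorch 2.x where torch.func.functional_call is available, and reusing model, inp1, and optimizer from the snippet above) is to run the second forward pass with detached views of the model's own weights, so gradients still flow back through pred but not into the parameters of that pass:

import torch.nn.functional as F
from torch.func import functional_call

# Detached views of the current parameters/buffers: no extra copy of the weights,
# and the second forward pass will not receive gradients for them.
frozen_state = {name: t.detach() for name, t in model.state_dict().items()}

pred = model(inp1, inp2)
pred_new = functional_call(model, frozen_state, (pred, inp1))  # grads still flow through pred
loss = F.mse_loss(inp1, pred_new)
loss.backward()
optimizer.step()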

To save some memory in the forward pass, after copying the model I set requires_grad = False on the copy's parameters.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
import copy

class Model(nn.Module):
    def __init__(self, grayscale=False):
        super().__init__()

        self.inp1 = nn.Linear(10, 2)
        self.inp2 = nn.Linear(10, 2)
        self.out = nn.Linear(4, 10)

    def forward(self, im1, im2):
        im1 = self.inp1(im1)
        im2 = self.inp2(im2)
        return self.out(torch.cat([im1, im2], dim=1))

model = Model()
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

# copy the original model and disable gradient calculation for the copy
loss_model = copy.deepcopy(model)
for p in loss_model.parameters():
    p.requires_grad = False

inp1 = torch.randn(1, 10)
inp2 = torch.randn(1, 10)
gt = torch.randn(1, 10)

pred = model(inp1, inp2)
pred_new = loss_model(pred, inp1)
loss = F.mse_loss(inp1, pred_new)
loss.backward()
optimizer.step()
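
As a quick sanity check (continuing the snippet above), the gradients should end up only in the trained model, not in the frozen copy:

print(model.inp1.weight.grad)       # populated: gradients flowed back through pred
print(loss_model.inp1.weight.grad)  # None: requires_grad=False on the copy

Inside a real training loop, an optimizer.zero_grad() call before backward() would also be needed so the gradients from the previous step don't accumulate.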