Autograd derivatives of a multi-output ANN

I have a simple ANN with input_dim=1 and output_dim=2.

import torch
import torch.nn as nn

class NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 128)    # input layer: t -> 128 hidden units
        self.fc  = nn.Linear(128, 128)  # hidden layer
        self.fc2 = nn.Linear(128, 2)    # output layer: (y1, y2)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc(x))
        x = self.fc2(x)
        return x

I am studying physics-informed neural networks (PINNs), and I want to calculate the ODE loss. Given that I have a system of 2 ODEs, how should I calculate the derivatives of the two outputs of the ANN? Currently, I am doing:

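import numpy as np

def physics_loss(model: nn.Module):
    # 250 collocation points on [0, 2]; requires_grad=True so autograd can differentiate w.r.t. t
    t = np.linspace(0, 2, 250)
    t_ts = torch.tensor(t, dtype=torch.float32, requires_grad=True).reshape(-1, 1)
    y_pred_ts = model(t_ts)
    y_pred_1_ts = y_pred_ts[:, 0].view(-1, 1)
    y_pred_2_ts = y_pred_ts[:, 1].view(-1, 1)

    # grad_outputs=ones yields dy/dt at every t_i, since each output row depends only on its own t_i;
    # create_graph=True keeps these derivatives differentiable w.r.t. the model parameters
    dy_pred_1_ts_dt = torch.autograd.grad(y_pred_1_ts, t_ts, torch.ones_like(y_pred_1_ts), create_graph=True)[0]
    dy_pred_2_ts_dt = torch.autograd.grad(y_pred_2_ts, t_ts, torch.ones_like(y_pred_2_ts), create_graph=True)[0]

    # residuals of dy1/dt = -5*y1 and dy2/dt = 5*y1 - y2
    error_1 = dy_pred_1_ts_dt + 5 * y_pred_1_ts
    error_2 = dy_pred_2_ts_dt - 5 * y_pred_1_ts + y_pred_2_ts

    return torch.mean(error_1**2) + torch.mean(error_2**2)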
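Reading the residuals off the code, error_1 and error_2 encode the linear system below, and the returned value is the mean squared residual over the N = 250 collocation points t_i in [0, 2]. The closed-form solution for initial values y_1(0) = a, y_2(0) = b is included for reference (a and b are symbols introduced here; the code above does not impose initial conditions):

```latex
\[
\begin{aligned}
\frac{dy_1}{dt} &= -5\,y_1, \qquad \frac{dy_2}{dt} = 5\,y_1 - y_2,\\
\mathcal{L}_{\text{phys}} &= \frac{1}{N}\sum_{i=1}^{N}\bigl(\hat y_1'(t_i) + 5\,\hat y_1(t_i)\bigr)^2
  + \frac{1}{N}\sum_{i=1}^{N}\bigl(\hat y_2'(t_i) - 5\,\hat y_1(t_i) + \hat y_2(t_i)\bigr)^2,\\
y_1(t) &= a\,e^{-5t}, \qquad y_2(t) = \Bigl(b + \tfrac{5a}{4}\Bigr)e^{-t} - \tfrac{5a}{4}\,e^{-5t}.
\end{aligned}
\]
```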

Is this correct?

Your code isn’t buggy and runs fine:

model = NN()
loss = physics_loss(model)
print(loss) # e.g. tensor(0.2380, grad_fn=<AddBackward0>); the exact value depends on the random initialization
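Why torch.ones_like(...) as grad_outputs gives the derivative at each point: torch.autograd.grad computes a vector-Jacobian product, so with a vector of ones it returns, for each t_i, the sum over j of ∂y_{1,j}/∂t_i. Because this network processes every sample independently (no batch norm or other cross-sample layers), that Jacobian is diagonal and the sum reduces to dy_1/dt at t_i. A small sketch that checks this against the full Jacobian from torch.autograd.functional.jacobian, assuming a 20-point grid and a fixed seed (both arbitrary choices for the check):

```python
import torch
import torch.nn as nn

class NN(nn.Module):
    # Same architecture as in the question: 1 -> 128 -> 128 -> 2
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 128)
        self.fc = nn.Linear(128, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc(x))
        return self.fc2(x)

torch.manual_seed(0)
model = NN()
n = 20
# clone() gives a fresh leaf tensor on which requires_grad can be set
t = torch.linspace(0, 2, n).reshape(-1, 1).clone().requires_grad_(True)

# Derivative as computed in physics_loss: vector-Jacobian product with a vector of ones
y1 = model(t)[:, 0:1]
dy1_dt = torch.autograd.grad(y1, t, torch.ones_like(y1), create_graph=True)[0]

# Full Jacobian of the first output w.r.t. all inputs: shape (n, 1, n, 1) -> (n, n)
jac = torch.autograd.functional.jacobian(lambda x: model(x)[:, 0:1], t).reshape(n, n)
off_diag = jac - torch.diag(torch.diagonal(jac))

print(torch.allclose(dy1_dt.squeeze(1), torch.diagonal(jac), atol=1e-6))  # True
print(off_diag.abs().max().item())  # 0.0: y1 at t_i does not depend on any other t_j
```

For larger grids the full Jacobian grows as n², which is why the ones-vector product used in the question is the practical choice; the Jacobian is only useful here as a check.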

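But I don’t quite see why you’d calculate the gradients of y_pred_1_ts, which is the model output, with respect to t_ts, which is the model input (the data), rather than with respect to the model parameters. Assuming you wanted to show a minimal executable snippet and ask whether torch.autograd.grad can be used the way you showed, I think you’re good to go. Also see .backward(), which computes the gradients of a scalar loss with respect to the leaf tensors (e.g. the model parameters) and accumulates them in their .grad attributes.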
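The two kinds of derivative fit together in a PINN: the derivatives with respect to the input t are taken inside the loss (with create_graph=True), and loss.backward() then differentiates that loss with respect to the parameters. A minimal training-loop sketch, assuming Adam with lr=1e-3 and 200 steps (arbitrary choices, not from the question). Note that this loss has no initial-condition term, so on its own it is also minimised by y ≡ 0; a full PINN would add such terms.

```python
import numpy as np
import torch
import torch.nn as nn

class NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 128)
        self.fc = nn.Linear(128, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc(x))
        return self.fc2(x)

def physics_loss(model: nn.Module):
    # Same residuals as in the question
    t = np.linspace(0, 2, 250)
    t_ts = torch.tensor(t, dtype=torch.float32, requires_grad=True).reshape(-1, 1)
    y = model(t_ts)
    y1, y2 = y[:, 0:1], y[:, 1:2]
    # create_graph=True so the derivative terms also contribute to the parameter gradients
    dy1_dt = torch.autograd.grad(y1, t_ts, torch.ones_like(y1), create_graph=True)[0]
    dy2_dt = torch.autograd.grad(y2, t_ts, torch.ones_like(y2), create_graph=True)[0]
    error_1 = dy1_dt + 5 * y1
    error_2 = dy2_dt - 5 * y1 + y2
    return torch.mean(error_1**2) + torch.mean(error_2**2)

torch.manual_seed(0)
model = NN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # assumed optimizer settings

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = physics_loss(model)  # derivatives w.r.t. the input t are inside the loss
    loss.backward()             # gradients w.r.t. the parameters, stored in p.grad
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```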

I want to calculate the physics loss term of a PINN, which is why I calculate the derivatives of the model output with respect to the model input.
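As a sanity check on the physics loss itself, one can feed physics_loss a stand-in "model" that returns the closed-form solution of the system dy1/dt = -5*y1, dy2/dt = 5*y1 - y2, namely y1 = a·e^(-5t) and y2 = (b + 5a/4)·e^(-t) - (5a/4)·e^(-5t) for y1(0) = a, y2(0) = b. If the residuals are written correctly, the loss should be at the level of float32 round-off. ExactSolution, a and b are hypothetical names introduced for this sketch.

```python
import numpy as np
import torch
import torch.nn as nn

def physics_loss(model: nn.Module):
    # Same residuals as in the question
    t = np.linspace(0, 2, 250)
    t_ts = torch.tensor(t, dtype=torch.float32, requires_grad=True).reshape(-1, 1)
    y_pred_ts = model(t_ts)
    y_pred_1_ts = y_pred_ts[:, 0].view(-1, 1)
    y_pred_2_ts = y_pred_ts[:, 1].view(-1, 1)
    dy1 = torch.autograd.grad(y_pred_1_ts, t_ts, torch.ones_like(y_pred_1_ts), create_graph=True)[0]
    dy2 = torch.autograd.grad(y_pred_2_ts, t_ts, torch.ones_like(y_pred_2_ts), create_graph=True)[0]
    error_1 = dy1 + 5 * y_pred_1_ts
    error_2 = dy2 - 5 * y_pred_1_ts + y_pred_2_ts
    return torch.mean(error_1**2) + torch.mean(error_2**2)

class ExactSolution(nn.Module):
    """Hypothetical stand-in for the network: the closed-form solution with y1(0)=a, y2(0)=b."""
    def __init__(self, a: float = 1.0, b: float = 0.0):
        super().__init__()
        self.a, self.b = a, b

    def forward(self, t):
        y1 = self.a * torch.exp(-5 * t)
        y2 = (self.b + 1.25 * self.a) * torch.exp(-t) - 1.25 * self.a * torch.exp(-5 * t)
        return torch.cat([y1, y2], dim=1)  # shape (N, 2), like the network output

exact_loss = physics_loss(ExactSolution(a=1.0, b=0.0)).item()
print(exact_loss)  # close to 0 (float32 round-off only)
```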