Computational graph breaking, but I cannot locate the reason

I am not sure where my computational graph is breaking, since the weights are not updating at all. Any help or suggestions on how to debug this would be highly appreciated. Thanks in advance. FYI: this is a simplified version of my original code, but it ends up with the same result (the computational graph breaking). I suspect it has to do with the VehicleDynamicsModel class.

```python
import torch
import torch.nn as nn

class PacejkaTireModel(nn.Module):
    def __init__(self, B, C, D, E):
        super(PacejkaTireModel, self).__init__()
        self.B = torch.tensor(B, requires_grad=True)
        self.C = torch.tensor(C, requires_grad=True)
        self.D = torch.tensor(D, requires_grad=True)
        self.E = torch.tensor(E, requires_grad=True)

    def forward(self, x):
        # Magic Formula
        y = self.D * torch.sin(self.C * torch.atan(self.B * x - self.E * (self.B * x - torch.atan(self.B * x))))
        return y

class VehicleDynamicsModel(nn.Module):
    def __init__(self, dt, wheelbase, mass, I_z):
        super(VehicleDynamicsModel, self).__init__()
        self.dt = dt
        self.wheelbase = wheelbase
        self.mass = mass
        self.I_z = I_z
        self.pacejka_lat = PacejkaTireModel(B=10.0, C=1.9, D=1.0, E=-1.2)
        self.pacejka_lon = PacejkaTireModel(B=10.0, C=1.9, D=1.0, E=-1.2)

    def forward(self, state, accel, steer):

        # Calculate slip angles
        slip_angle_front = torch.atan2(state[4] + state[5] * self.wheelbase / 2, state[3]) - steer
        slip_angle_rear = torch.atan2(state[4] - state[5] * self.wheelbase / 2, state[3])

        # Calculate tire forces
        F_yf = self.pacejka_lat(slip_angle_front)
        F_yr = self.pacejka_lat(slip_angle_rear)
        F_xr = self.pacejka_lon(accel)

        # Calculate dynamics
        a_x = (F_xr - F_yf * torch.sin(steer)) / self.mass
        a_y = (F_yf * torch.cos(steer) + F_yr) / self.mass
        a_r = (F_yf * self.wheelbase / 2 * torch.cos(steer) - F_yr * self.wheelbase / 2) / self.I_z

        state_dot = torch.zeros(6)
        # Update state
        state_dot[0] = state[3] * torch.cos(state[2]) - state[4] * torch.sin(state[2])
        state_dot[1] = state[3] * torch.sin(state[2]) + state[4] * torch.cos(state[2])
        state_dot[2] = state[5]
        state_dot[3] = a_x
        state_dot[4] = a_y
        state_dot[5] = a_r

        return state + state_dot*self.dt
```

```python
import torch.nn as nn
import torch.optim as optim

class ControlNetwork(nn.Module):
    def __init__(self):
        super(ControlNetwork, self).__init__()
        self.fc1 = nn.Linear(6, 32)
        self.fc2 = nn.Linear(32, 2)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        control = self.fc2(x)
        return control

loss_fn = nn.MSELoss()
# Instantiate the neural network, vehicle dynamics model, and optimizer
control_net = ControlNetwork()
vehicle_dynamics = VehicleDynamicsModel(dt=0.1, wheelbase=2.5, mass=1500, I_z=3000)
optimizer = optim.Adam(control_net.parameters(), lr=0.001)

# Create a toy dataset
state = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0], dtype=torch.float32)
target_position = torch.tensor([10.0, 10.0], dtype=torch.float32)

for epoch in range(20):
    # Train the neural network for one epoch
    optimizer.zero_grad()

    # Predict the control inputs (steering and acceleration)
    control = control_net(state)
    accel, steer = control[0], control[1]

    # Simulate the vehicle dynamics
    next_state = vehicle_dynamics(state, accel, steer)

    # Calculate the loss using mean squared error with the target position
    loss = loss_fn(next_state[0:2], target_position)

    # Backpropagate the loss through the network
    loss.backward()
    optimizer.step()

    norm_p = 0
    for p in control_net.parameters():
        norm_p += torch.norm(p)

    print(f"Loss: {loss.item()}, Norm of weights: {norm_p.item()}")
```
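One way to test the suspicion about VehicleDynamicsModel in isolation is to check whether writing values by index into a pre-allocated `torch.zeros(6)` (the way `state_dot` is built above) keeps the autograd history. This is a minimal standalone sketch, not part of the original code; the scalar `u` is a hypothetical stand-in for one of the network outputs.

```python
import torch

# Hypothetical stand-in for one control output (e.g. steer); a leaf that requires grad.
u = torch.tensor(0.5, requires_grad=True)

# Same pattern as in VehicleDynamicsModel: allocate zeros, then assign entries by index.
state_dot = torch.zeros(6)
state_dot[0] = 1.0          # constant entry, does not depend on u
state_dot[3] = 2.0 * u      # entry that depends on u

# The index assignment is recorded by autograd, so state_dot is now part of the graph.
print(state_dot.requires_grad, state_dot.grad_fn)

# Gradient flows back to u through the entry that depends on it.
state_dot[3].backward()
print(u.grad)               # tensor(2.)

# Out-of-place alternative with torch.stack, which avoids in-place writes altogether.
zero = torch.zeros(())
state_dot2 = torch.stack([torch.ones(()), zero, zero, 2.0 * u, zero, zero])
print(state_dot2.requires_grad)
```

If this prints `True` for `requires_grad` and a non-None `grad_fn`, the in-place construction of `state_dot` by itself is not what cuts the graph.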

If you suspect your autograd graph is not being built during the forward pass, you can print out whether the intermediate tensors require grad.
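A minimal sketch of that check, assuming a hypothetical two-layer `nn.Sequential` in place of ControlNetwork and a hypothetical helper `report`: `requires_grad` says whether a tensor is attached to the graph, and `grad_fn` names the operation that produced it. The first tensor that shows `requires_grad=False` and `grad_fn=None` marks where the graph was cut (for example by `.detach()`).

```python
import torch
import torch.nn as nn

def report(name, t):
    # requires_grad: is t attached to the graph? grad_fn: the op that produced it (None if detached or a leaf).
    fn = type(t.grad_fn).__name__ if t.grad_fn is not None else None
    print(f"{name}: requires_grad={t.requires_grad}, grad_fn={fn}")

# Hypothetical stand-in for ControlNetwork.
net = nn.Sequential(nn.Linear(6, 32), nn.ReLU(), nn.Linear(32, 2))
state = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0])

control = net(state)
report("control", control)   # attached to the graph

scaled = control * 2.0
report("scaled", scaled)     # still attached: grad_fn=MulBackward0

cut = control.detach() * 2.0
report("cut", cut)           # detached: requires_grad=False, grad_fn=None
```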

Also, you can insert autograd backward hooks (see the Autograd mechanics page of the PyTorch 2.0 documentation) to observe where the gradients aren’t being backpropagated. Is it possible that the derivative with respect to the parameters you are observing is zero?
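One way to answer that question directly is `torch.autograd.grad`, which returns the derivative of the loss with respect to chosen tensors without touching any `.grad` fields. The sketch below is a reduction under stated assumptions: `euler_step` is a hypothetical function that keeps the state-update structure of `VehicleDynamicsModel.forward` but replaces the tire-force accelerations with simple placeholder expressions in `accel` and `steer`. In that update, x and y advance using only the current heading and velocities (`state[2]`, `state[3]`, `state[4]`), while the controls only enter the velocity derivatives. In the training loop above, `state` is never reassigned, so every iteration simulates a single step from the same initial state.

```python
import torch

dt = 0.1
state = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0])
target = torch.tensor([10.0, 10.0])
accel = torch.tensor(0.3, requires_grad=True)
steer = torch.tensor(0.1, requires_grad=True)

def euler_step(state, accel, steer):
    # Hypothetical reduction of VehicleDynamicsModel.forward (placeholder accelerations).
    x_dot = state[3] * torch.cos(state[2]) - state[4] * torch.sin(state[2])
    y_dot = state[3] * torch.sin(state[2]) + state[4] * torch.cos(state[2])
    a_x = accel - torch.sin(steer)   # placeholder for (F_xr - F_yf * sin(steer)) / mass
    a_y = torch.cos(steer)           # placeholder for (F_yf * cos(steer) + F_yr) / mass
    state_dot = torch.stack([x_dot, y_dot, state[5], a_x, a_y, torch.zeros(())])
    return state + dt * state_dot

# One step: position depends only on the old velocities, so d(loss)/d(controls) is zero.
one_step = euler_step(state, accel, steer)
loss1 = ((one_step[:2] - target) ** 2).mean()
grads_one = torch.autograd.grad(loss1, (accel, steer))
print(grads_one)

# Two steps: the controls change the velocity, which then moves the position.
two_step = euler_step(euler_step(state, accel, steer), accel, steer)
loss2 = ((two_step[:2] - target) ** 2).mean()
grads_two = torch.autograd.grad(loss2, (accel, steer))
print(grads_two)
```

Under these assumptions the graph is intact in both cases (no error is raised), but only the multi-step rollout produces a non-zero gradient for the controls.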

Adding the following

```python
print('Output of the ControlNetwork:')
print(f"Grad: {control.data.grad}")
print(f"Requires grad: {control.requires_grad}")
print('\n')
print('Output of the VehicleDynamicsModel:')
print(f"Grad: {next_state.data.grad}")
print(f"Requires grad: {next_state.requires_grad}")
print('\n')
print('Output of the Loss Function:')
print(f"Grad: {loss.data.grad}")
print(f"Requires grad: {loss.requires_grad}")
```

at the end of the training for loop, and putting a break after the first epoch, produces this:

```
Output of the ControlNetwork:
Grad: None
Requires grad: True


Output of the VehicleDynamicsModel:
Grad: None
Requires grad: True


Output of the Loss Function:
Grad: None
Requires grad: True
```

Is this what you wanted to see?

Also, I am really not sure how to integrate the autograd backward hooks. Thanks for your time.

You could do:

```python
def hook(gradient):
    print(gradient)

control.register_hook(hook)
```
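The same idea extends to several intermediates at once: register a labeled hook on each tensor along the path from the network output to the loss, and every gradient arriving during `backward()` is printed with its name. This is a sketch on a hypothetical toy pipeline (a single `nn.Linear` stand-in for ControlNetwork, and made-up velocity and position updates), not the thread's model; `save_grad` and the `grads` dict are illustrative names.

```python
import torch
import torch.nn as nn

grads = {}

def save_grad(name):
    # Returns a hook that records and prints the gradient arriving at a tensor.
    def hook(grad):
        grads[name] = grad.clone()
        print(f"{name}: {grad}")
    return hook

net = nn.Linear(6, 2)                                  # hypothetical stand-in for ControlNetwork
state = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0])
target = torch.tensor([10.0, 10.0])

control = net(state)
control.register_hook(save_grad("control"))

new_vel = state[3:5] + 0.1 * control                   # hypothetical: controls nudge the velocity
new_vel.register_hook(save_grad("new_vel"))

new_pos = state[0:2] + 0.1 * new_vel                   # position integrates the updated velocity
new_pos.register_hook(save_grad("new_pos"))

loss = ((new_pos - target) ** 2).mean()
loss.backward()   # hooks fire from the loss backwards: new_pos, new_vel, control
```

A hook only fires if its tensor lies on a path to the loss; if `new_pos` were computed from `state[3:5]` instead of `new_vel`, neither the `new_vel` nor the `control` hook would print anything.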

Why do you check the .grad of the .data field? You should be checking the .grad of the tensor directly.
Also, you shouldn’t be checking the .grad of intermediates anyway: only the .grad of leaf tensors will be populated.
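A minimal sketch of that rule, using a hypothetical `nn.Linear` layer: after `backward()`, leaf tensors that require grad (such as the layer's weight) get a populated `.grad`, while an intermediate such as the network output keeps `.grad = None` unless `retain_grad()` is called on it before the backward pass.

```python
import torch
import torch.nn as nn

net = nn.Linear(6, 2)        # hypothetical stand-in for ControlNetwork
state = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0])

control = net(state)         # non-leaf: produced by an operation
control.retain_grad()        # ask autograd to keep this intermediate's gradient

loss = (control ** 2).sum()
loss.backward()

print(net.weight.is_leaf, net.weight.grad is not None)  # True True
print(control.is_leaf, control.grad)                    # False, and a gradient equal to 2 * control
```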

The hook function outputs the following:

```
tensor([0., 0.])
```

I also modified the code to read the .grad of the tensor directly and to check whether it is a leaf:

```python
print('Output of the ControlNetwork:')
print(f"Grad: {control.grad}")
print(f"Requires grad: {control.requires_grad}")
print(f"Is leaf: {control.is_leaf}")
```

This is what I get, i.e., control is not a leaf:

```
Output of the ControlNetwork:
Grad: None
Requires grad: True
Is leaf: False
```
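Since `control` is not a leaf, `control.grad` being None says nothing about whether the graph is broken. A more informative check is on the leaves the optimizer actually updates: after `loss.backward()`, a parameter whose `.grad` is still None was never reached by the backward pass (a genuinely cut graph), whereas a `.grad` tensor that exists but is all zeros means the graph is intact and the derivative itself is zero. The hook output `tensor([0., 0.])` shown earlier fits the second case. The sketch below reproduces both cases with hypothetical toy losses and a hypothetical helper `grad_summary`.

```python
import torch
import torch.nn as nn

def grad_summary(module):
    # For each parameter: 'None' -> not reached by backward; max|grad|=0 -> reached, derivative is zero.
    out = {}
    for name, p in module.named_parameters():
        out[name] = "None" if p.grad is None else f"max|grad|={p.grad.abs().max().item():.3g}"
    return out

state = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0])

# Case 1: the network output is detached, so backward never reaches the network.
net_a = nn.Linear(6, 2)
w = torch.ones(2, requires_grad=True)
loss_a = (net_a(state).detach() * w).sum()
loss_a.backward()
print(grad_summary(net_a))   # {'weight': 'None', 'bias': 'None'}

# Case 2: the graph is intact, but the loss does not change with the network output.
net_b = nn.Linear(6, 2)
loss_b = (0.0 * net_b(state)).sum()
loss_b.backward()
print(grad_summary(net_b))   # {'weight': 'max|grad|=0', 'bias': 'max|grad|=0'}
```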