I am not sure where is my computational graph breaking here since the weights are not updating at all. Any help/suggestion to debug this will be highly appreciated. Thanks in advance. FYI: This is a simpler version of my original code but eventually results in the same - computational graph breaking. And I suspect it has to do with the VehicleDynamicsModel class.
import torch
import torch.nn as nn
class PacejkaTireModel(nn.Module):
def __init__(self, B, C, D, E):
super(PacejkaTireModel, self).__init__()
self.B = torch.tensor(B, requires_grad=True)
self.C = torch.tensor(C, requires_grad=True)
self.D = torch.tensor(D, requires_grad=True)
self.E = torch.tensor(E, requires_grad=True)
def forward(self, x):
# Magic Formula
y = self.D * torch.sin(self.C * torch.atan(self.B * x - self.E * (self.B * x - torch.atan(self.B * x))))
return y
class VehicleDynamicsModel(nn.Module):
def __init__(self, dt, wheelbase, mass, I_z):
super(VehicleDynamicsModel, self).__init__()
self.dt = dt
self.wheelbase = wheelbase
self.mass = mass
self.I_z = I_z
self.pacejka_lat = PacejkaTireModel(B=10.0, C=1.9, D=1.0, E=-1.2)
self.pacejka_lon = PacejkaTireModel(B=10.0, C=1.9, D=1.0, E=-1.2)
def forward(self, state, accel, steer):
# Calculate slip angles
slip_angle_front = torch.atan2(state[4] + state[5] * self.wheelbase / 2, state[3]) - steer
slip_angle_rear = torch.atan2(state[4] - state[5] * self.wheelbase / 2, state[3])
# Calculate tire forces
F_yf = self.pacejka_lat(slip_angle_front)
F_yr = self.pacejka_lat(slip_angle_rear)
F_xr = self.pacejka_lon(accel)
# Calculate dynamics
a_x = (F_xr - F_yf * torch.sin(steer)) / self.mass
a_y = (F_yf * torch.cos(steer) + F_yr) / self.mass
a_r = (F_yf * self.wheelbase / 2 * torch.cos(steer) - F_yr * self.wheelbase / 2) / self.I_z
state_dot = torch.zeros(6)
# Update state
state_dot[0] = state[3] * torch.cos(state[2]) - state[4] * torch.sin(state[2])
state_dot[1] = state[3] * torch.sin(state[2]) + state[4] * torch.cos(state[2])
state_dot[2] = state[5]
state_dot[3] = a_x
state_dot[4] = a_y
state_dot[5] = a_r
return state + state_dot*self.dt
import torch.nn as nn
import torch.optim as optim
class ControlNetwork(nn.Module):
def __init__(self):
super(ControlNetwork, self).__init__()
self.fc1 = nn.Linear(6, 32)
self.fc2 = nn.Linear(32, 2)
def forward(self, state):
x = torch.relu(self.fc1(state))
control = self.fc2(x)
return control
loss_fn = nn.MSELoss()
# Instantiate the neural network, vehicle dynamics model, and optimizer
control_net = ControlNetwork()
vehicle_dynamics = VehicleDynamicsModel(dt=0.1, wheelbase=2.5, mass=1500, I_z=3000)
optimizer = optim.Adam(control_net.parameters(), lr=0.001)
# Create a toy dataset
state = torch.tensor([0.0, 0.0, 0.0, 10.0, 0.0, 0.0], dtype=torch.float32)
target_position = torch.tensor([10.0, 10.0], dtype=torch.float32)
for epoch in range(20):
# Train the neural network for one epoch
optimizer.zero_grad()
# Predict the control inputs (steering and acceleration)
control = control_net(state)
accel, steer = control[0], control[1]
# Simulate the vehicle dynamics
next_state = vehicle_dynamics(state, accel, steer)
# Calculate the loss using mean squared error with the target position
loss = loss_fn(next_state[0:2], target_position)
# Backpropagate the loss through the network
loss.backward()
optimizer.step()
norm_p = 0
for p in control_net.parameters():
norm_p += torch.norm(p)
print(f"Loss: {loss.item()}, Norm of weights: {norm_p.item()}")