Gradients are None for the actor after calling loss.backward()

I am trying to use a graph convolution network together with TD3 to build a drone guidance scheme. The problem is that although both critics get gradients after calling backward(), the actor does not. I have tried using backward hooks to track down the issue, but I cannot get an output from any of the layers. Can anyone see why the critics would work but the actor won't?
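For reference, this is roughly how I check which actor parameters actually receive gradients after the backward pass (a minimal diagnostic sketch, using the self.actor and actor_loss objects from the training code below):

    # Diagnostic sketch: run right after actor_loss.backward() in train()
    # and list which actor parameters received a gradient.
    for name, param in self.actor.named_parameters():
        if param.grad is None:
            print(name, "-> grad is None")
        else:
            print(name, "-> grad norm", param.grad.norm().item())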

Here is the code for reproduction:

import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils as nn_utils
import torch_geometric
from torch.optim import Adam, SGD
from torch_geometric.nn import GCNConv, GATv2Conv, global_mean_pool
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_undirected, grid
from torch_geometric.transforms import ToSparseTensor
from torchviz import make_dot
import csv
import os
import random
import numpy as np
import psutil

def backward_hook(module, grad_input, grad_output):
    print("Actor Network Backward Hook:")
    print("Gradient Input:", grad_input)
    print("Gradient Output:", grad_output)

def depth_image_to_graph_batch(depth_images):
    # Convert to PyTorch tensor
    depth_images = depth_images.clone().detach().requires_grad_(True).to('cuda')

    batch_size, channels, height, width = depth_images.size()
    flattened_image = depth_images.view(batch_size, -1)
    # Create a grid of node indices corresponding to pixel positions
    pos_x, pos_y = torch.meshgrid(torch.arange(height), torch.arange(width))
    pos_x = pos_x.reshape(-1).float().to('cuda')
    pos_y = pos_y.reshape(-1).float().to('cuda')

    # Normalize pixel coordinates
    pos_x = (pos_x / (height - 1)) * 2 - 1
    pos_y = (pos_y / (width - 1)) * 2 - 1
    data_list = []
    for i in range(batch_size):
        # Stack coordinates to form node features
        node_features = torch.stack([pos_x, pos_y, flattened_image[i]], dim=-1)

        # Create edge indices for a grid graph
        edge_index = grid(height, width, device='cuda', dtype=torch.long)

        # Create a PyTorch Geometric Data object
        data = Data(x=node_features, edge_index=edge_index, device='cuda').to('cuda')
        data_list.append(data)
    batch_data = Batch.from_data_list(data_list).to('cuda')
    return batch_data

class ActorModel(nn.Module):
    def __init__(self, graph_input_dim, secondary_input_dim, action_dim, action_scaling=2.5, num_heads=1):
        super(ActorModel, self).__init__()
        self.conv1 = GCNConv(graph_input_dim, 128)
        self.conv2 = GCNConv(128, 64)
        self.conv3 = GATv2Conv(64, 64, heads=num_heads)
        self.dense1 = nn.Linear(64, 256)
        self.dense2 = nn.Linear(256, 128)
        self.dense3 = nn.Linear(4, 64)
        self.dense4 = nn.Linear(192, 64)
        self.dense5 = nn.Linear(64, action_dim)
        self.action_scaling = action_scaling
        self.global_pool = global_mean_pool

    def forward(self, data, secondary_input):
        # Convert to PyTorch Geometric Data object
        data = depth_image_to_graph_batch(data)

        # Secondary input
        secondary_in = torch.tensor(secondary_input, dtype=torch.float32).clone().detach().requires_grad_(True).to('cuda')

        # Apply the first GCNConv layer
        h = F.elu(self.conv1(data.x, data.edge_index[0]))

        # Apply the second GCNConv layer
        h = F.elu(self.conv2(h, data.edge_index[0]))

        # Apply attention
        atten = F.elu(self.conv3(h, data.edge_index[0]))
        h = h + atten

        # Global pooling
        h = self.global_pool(h, data.batch)

        # Apply the dense layers
        h = F.elu(self.dense1(h))
        h = F.elu(self.dense2(h))
        m = self.dense3(secondary_in)

        # Concatenate the features
        h = torch.cat([h, m], dim=1)
        h = F.elu(self.dense4(h))

        # Apply the output layer and scale the action
        action = F.tanh(self.dense5(h))
        scaled_action = action * self.action_scaling
        return scaled_action

# Define the critic network for the TD3 algorithm

class CriticModel(nn.Module):
    def __init__(self, graph_input_dim, secondary_input_dim, action_dim, num_heads=1):
        super(CriticModel, self).__init__()
        self.conv1 = GCNConv(graph_input_dim, 128)
        self.conv2 = GCNConv(128, 64)
        self.conv3 = GATv2Conv(64, 64, heads=num_heads)
        self.dense_state = nn.Linear(64, 256)
        self.dense3 = nn.Linear(4, 64)
        self.dense4 = nn.Linear(256, 64)
        self.dense5 = nn.Linear(256, 128)
        self.dense_action = nn.Linear(action_dim, 64)
        self.dense_combined = nn.Linear(64, 1)
        self.global_pool = global_mean_pool

    def forward(self, data, secondary_input, action):
        # Convert to PyTorch Geometric Data object
        data = depth_image_to_graph_batch(data)

        # Secondary and action inputs
        secondary_in = torch.tensor(secondary_input, dtype=torch.float32).clone().detach().requires_grad_(True).to('cuda')
        action_in = torch.tensor(action, dtype=torch.float32).clone().detach().requires_grad_(True).to('cuda')

        # Apply the first GCNConv layer
        h = F.elu(self.conv1(data.x, data.edge_index[0]))

        # Apply the second GCNConv layer
        h = F.elu(self.conv2(h, data.edge_index[0]))

        # Apply attention
        atten = F.elu(self.conv3(h, data.edge_index[0]))
        h = h + atten

        # Global pooling
        h_state = self.global_pool(h, data.batch)

        # Apply the dense layers for state
        h_state = F.elu(self.dense_state(h_state))
        h_state = F.elu(self.dense5(h_state))

        # Secondary input
        m = F.elu(self.dense3(secondary_in))

        # Action input
        h_action = F.elu(self.dense_action(action_in))

        # Combine features
        h_combined = torch.cat([h_state, m, h_action], dim=1)

        # Apply the dense layers
        x = F.elu(self.dense4(h_combined))

        # Output value
        value = self.dense_combined(x)

        return value

class TD3Agent:
    def __init__(self, state_dim, action_dim, hidden_units, buffer_size, batch_size, discount, tau, actor_lr, critic_lr, policy_noise, noise_clip, policy_delay):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_units = hidden_units
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.discount = discount
        self.tau = tau
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_delay = policy_delay

        self.actor = ActorModel(graph_input_dim=3, secondary_input_dim=4, action_dim=2).to('cuda')
        self.critic1 = CriticModel(graph_input_dim=3, secondary_input_dim=4, action_dim=2).to('cuda')
        self.critic2 = CriticModel(graph_input_dim=3, secondary_input_dim=4, action_dim=2).to('cuda')
        self.target_actor = ActorModel(graph_input_dim=3, secondary_input_dim=4, action_dim=2).to('cuda')
        self.target_critic1 = CriticModel(graph_input_dim=3, secondary_input_dim=4, action_dim=2).to('cuda')
        self.target_critic2 = CriticModel(graph_input_dim=3, secondary_input_dim=4, action_dim=2).to('cuda')

        self.critic2.dense5.register_full_backward_hook(backward_hook)

        self.target_actor.load_state_dict(self.actor.state_dict())
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())

        self.actor_optimizer = Adam(self.actor.parameters(), lr=actor_lr, weight_decay=1e-2)
        self.critic1_optimizer = Adam(self.critic1.parameters(), lr=critic_lr, weight_decay=1e-2)
        self.critic2_optimizer = Adam(self.critic2.parameters(), lr=critic_lr, weight_decay=1e-2)

    def train(self):
        depth_images = torch.rand((32, 1, 16, 128)).to('cuda')
        secondary_input = torch.rand((32, 4)).to('cuda')
        depth_images2 = torch.rand((32, 1, 16, 128)).to('cuda')
        secondary_input2 = torch.rand((32, 4)).to('cuda')
        actions = torch.rand((32, 2)).to('cuda')
        states = [depth_images, secondary_input, actions]
        next_states = [depth_images2, secondary_input2, actions]

        rewards = np.random.randn(32, 1)
        dones = np.random.randint(0, 1, (32, 1))

        # Update target networks
        for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

        for target_param, param in zip(self.target_critic1.parameters(), self.critic1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

        for target_param, param in zip(self.target_critic2.parameters(), self.critic2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1.0 - self.tau) * target_param.data)

        # Update critics
        with torch.autograd.detect_anomaly():
            target_actions = self.target_actor(next_states[0], next_states[1])
            target_actions += torch.clamp(torch.normal(0, self.policy_noise, target_actions.shape), -self.noise_clip, self.noise_clip).to('cuda')
            target_actions = torch.clamp(target_actions, -2.5, 2.5).to('cuda')
            next_state_target_values1 = self.target_critic1(next_states[0], next_states[1], target_actions)
            next_state_target_values2 = self.target_critic2(next_states[0], next_states[1], target_actions)
            next_state_target_value = torch.min(next_state_target_values1, next_state_target_values2).to('cpu')
            next_state_target_value = next_state_target_value.detach().numpy()
            q_values1 = self.critic1(states[0], states[1], actions)
            q_values2 = self.critic2(states[0], states[1], actions)
            q_targets = rewards.reshape(32, 1) + self.discount * next_state_target_value * dones.reshape(32, 1)
            q_targets = torch.tensor(q_targets, dtype=torch.float32).to('cuda')
            critic1_loss = F.smooth_l1_loss(q_values1, q_targets)
            critic2_loss = F.smooth_l1_loss(q_values2, q_targets)

        self.critic1_optimizer.zero_grad()
        self.critic2_optimizer.zero_grad()

        critic1_loss.backward(retain_graph=True)
        critic2_loss.backward()

        self.critic1_optimizer.step()
        self.critic2_optimizer.step()

        # Policy update
        actions_pred = self.actor(states[0], states[1])
        q_values11 = self.critic1(states[0], states[1], actions_pred)
        actor_loss = -torch.mean(q_values11)
        for param in self.actor.parameters():
            param.requires_grad = True
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

state_dim = [[1,1,16,128],[2,1]]
action_dim = 2
hidden_units = 256
buffer_size = 50000
batch_size = 32
discount = 0.99
tau = 0.005
actor_lr = 1e-3
critic_lr = 1e-3
policy_noise = 0.5
noise_clip = 2.5
policy_delay = 2
b_timestep = 200
cnt = 32
learn = 32
agent = TD3Agent(state_dim, action_dim, hidden_units, buffer_size, batch_size, discount, tau, actor_lr, critic_lr,
                 policy_noise, noise_clip, policy_delay)

agent.train()

I found the issue: the lines inside forward() that converted the models' inputs to new tensors detached them from the computation graph. Deleting those lines fixed the problem; I just have to make sure the inputs are already tensors on the right device before passing them to the actor and critics.
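To make the fix concrete, here is roughly what the start of the critic's forward() looks like after deleting those lines (a sketch, not my exact code). As far as I can tell, the path that matters for the actor is the action input of critic1: rebuilding it with torch.tensor(...).clone().detach() cut the link to actions_pred, so no gradient could flow back into the actor's parameters.

    def forward(self, data, secondary_input, action):
        # Inputs are already float32 tensors on 'cuda', so use them directly;
        # no torch.tensor(...).clone().detach().requires_grad_(True) wrappers.
        data = depth_image_to_graph_batch(data)
        secondary_in = secondary_input
        action_in = action  # stays attached to the graph of actions_pred
        # ... rest of forward() unchanged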

Hi @TomH, I have met the same problem as you. It can be found here: https://discuss.pytorch.org/t/loss-not-converge-in-ddpg/191523/7?u=yxz77777
Could you please give me some advice, or show more details about your solution?