Hi @Lei_Shi1,
Your gradients are coming out as None because you are creating new instances of all three classes inside the forward method of combined_model.
To be able to update the parameters of model (an instance of combined_model), you could do something like this:
(Although there could be cleaner ways to do this without creating 3 separate classes)
from calendar import EPOCH
import torch
from torch import nn, tensor
import os
# --- Hyperparameters and device selection ---
LEARNING_RATE = 1e-5
EPOCHS = int(1e7)  # upper bound only; the demo loop below breaks after one step
BATCH_SIZE = 1
horizon = 8  # sequence length of the action input fed to the LSTM
WEIGHT_DECAY = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class action_input_model(nn.Module):
    """Action encoder layers: per-step features 2 -> 16 -> 16.

    Only declares the layers; the forward pass lives in combined_model,
    which mixes this class in and references these attributes by name.
    """
    def __init__(self):
        # Cooperative super() call: continues along the MRO so the sibling
        # mixin classes' __init__ methods also run exactly once.
        super().__init__()
        self.action_dense1 = nn.Linear(2, 16) # verified input (8,2), output(8,16) when nn.Linear(2,16)
        self.action_relu3 = nn.ReLU()
        self.action_dense2 = nn.Linear(16, 16)
class rnn_cell(nn.Module):
    """Recurrent core: stacked LSTM mapping 16 input features to 64 hidden units."""
    def __init__(self):
        super().__init__()
        # NOTE(review): the third positional argument (8) is num_layers — 8 stacked
        # LSTM layers — not the sequence length; it only coincidentally equals
        # `horizon`. Confirm that 8 layers is actually intended.
        self.rnn_cell = nn.LSTM(16, 64, 8, batch_first = True) # (input_size, hidden_size/num_units, num_layers)
class output_model_1(nn.Module):
    """Decoder head: per-step features 64 -> 32 -> 4."""
    def __init__(self):
        super().__init__()
        self.output_dense1 = nn.Linear(64, 32) # hidden layer features are 64
        self.output_relu3 = nn.ReLU()
        self.output_dense2 = nn.Linear(32, 4) # 4 is the output dimension, actually 8*4
class combined_model(action_input_model, rnn_cell, output_model_1):
    """Full network assembled from the three mixin classes above.

    The layers are declared in the parent classes (their cooperative
    super().__init__() chain registers every submodule once); this class
    only wires them together in forward().
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Action encoder: apply the dense/ReLU/dense stack in order.
        for encode in (self.action_dense1, self.action_relu3, self.action_dense2):
            x = encode(x)
        # nn.LSTM returns (output, (h_n, c_n)); keep only the per-step outputs.
        features, _ = self.rnn_cell(x)
        # Decoder head: dense/ReLU/dense down to 4 values per time step.
        for decode in (self.output_dense1, self.output_relu3, self.output_dense2):
            features = decode(features)
        return features
# --- One demonstration training step (the loop breaks after the first iteration) ---
action_input_data = torch.randn(BATCH_SIZE, horizon, 2, device=device)
model = combined_model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
for step in range(EPOCHS):
    model_output = model(action_input_data)  # shape (1, horizon, 4)
    # Random stand-ins for real supervision targets.
    ground_truth_position = torch.randn(1, horizon, 3, device=device)
    ground_truth_collision = torch.randn(1, horizon, device=device)
    loss_mse = nn.MSELoss(reduction='mean')
    # Position loss compares only the first two channels of both tensors.
    loss_position = loss_mse(model_output[:, :, :2], ground_truth_position[:,:,:2])
    loss_position.retain_grad()  # keep .grad on this non-leaf tensor for inspection
    loss_cross_entropy = nn.CrossEntropyLoss(reduction='sum')
    # NOTE(review): with a (1, horizon) float target, CrossEntropyLoss treats the
    # horizon dimension as classes and the targets as class probabilities — which
    # is why the printed loss can be negative. BCEWithLogitsLoss is the usual
    # choice for a per-step collision flag; confirm what is intended here.
    loss_collision = loss_cross_entropy(model_output[:, :, 3], ground_truth_collision)
    if loss_collision != 0:
        print('loss_collision', loss_collision)
    loss = loss_position + loss_collision
    loss.retain_grad()
    optimizer.zero_grad()
    loss.backward()
    print('loss grad is', loss.grad)  # d(loss)/d(loss) == 1.0
    # Every parameter now has a populated .grad — the point of this example.
    for name, p in model.named_parameters():
        print(name, p.grad.shape)
    optimizer.step()
    break  # demo only: stop after a single optimisation step
This works and gives the following output:
loss_collision tensor(-4.4471, grad_fn=<NegBackward0>)
loss grad is tensor(1.)
output_dense1.weight torch.Size([32, 64])
output_dense1.bias torch.Size([32])
output_dense2.weight torch.Size([4, 32])
output_dense2.bias torch.Size([4])
rnn_cell.weight_ih_l0 torch.Size([256, 16])
rnn_cell.weight_hh_l0 torch.Size([256, 64])
rnn_cell.bias_ih_l0 torch.Size([256])
rnn_cell.bias_hh_l0 torch.Size([256])
rnn_cell.weight_ih_l1 torch.Size([256, 64])
rnn_cell.weight_hh_l1 torch.Size([256, 64])
rnn_cell.bias_ih_l1 torch.Size([256])
rnn_cell.bias_hh_l1 torch.Size([256])
rnn_cell.weight_ih_l2 torch.Size([256, 64])
rnn_cell.weight_hh_l2 torch.Size([256, 64])
rnn_cell.bias_ih_l2 torch.Size([256])
rnn_cell.bias_hh_l2 torch.Size([256])
rnn_cell.weight_ih_l3 torch.Size([256, 64])
rnn_cell.weight_hh_l3 torch.Size([256, 64])
rnn_cell.bias_ih_l3 torch.Size([256])
rnn_cell.bias_hh_l3 torch.Size([256])
rnn_cell.weight_ih_l4 torch.Size([256, 64])
rnn_cell.weight_hh_l4 torch.Size([256, 64])
rnn_cell.bias_ih_l4 torch.Size([256])
rnn_cell.bias_hh_l4 torch.Size([256])
rnn_cell.weight_ih_l5 torch.Size([256, 64])
rnn_cell.weight_hh_l5 torch.Size([256, 64])
rnn_cell.bias_ih_l5 torch.Size([256])
rnn_cell.bias_hh_l5 torch.Size([256])
rnn_cell.weight_ih_l6 torch.Size([256, 64])
rnn_cell.weight_hh_l6 torch.Size([256, 64])
rnn_cell.bias_ih_l6 torch.Size([256])
rnn_cell.bias_hh_l6 torch.Size([256])
rnn_cell.weight_ih_l7 torch.Size([256, 64])
rnn_cell.weight_hh_l7 torch.Size([256, 64])
rnn_cell.bias_ih_l7 torch.Size([256])
rnn_cell.bias_hh_l7 torch.Size([256])
action_dense1.weight torch.Size([16, 2])
action_dense1.bias torch.Size([16])
action_dense2.weight torch.Size([16, 16])
action_dense2.bias torch.Size([16])
Hope this helps,
S