How to retain grad when a model is called inside another model

I am working on an Actor-Critic where the model is composed of three other models: two RNNs and one policy net. They are:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTMModelPublic(nn.Module):
    def __init__(self, state_size, cnn_output_dim, input_dim, hidden_dim, layer_dim, dropout, output_dim):
        super(LSTMModelPublic, self).__init__()
        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim

        #CNN layer
        self.cnn = nn.Sequential(nn.Conv1d(state_size, cnn_output_dim, 3, stride=1, padding = 1), nn.ReLU())

        #FC layer
        self.raw_fc  = nn.Sequential(nn.Linear(input_dim, 64), nn.ReLU())

        #RNN layer
        self.lstm = nn.LSTM(
            64, hidden_dim, layer_dim, batch_first=True, dropout= dropout
        )

        #FC layers
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim*2),
            nn.ReLU(),
            nn.Linear(hidden_dim*2, 32),
            nn.ReLU(),
            nn.Linear(32, output_dim)
        )



    def forward(self, x):
        x = self.cnn(x)
        #print('after cnn', x.shape)

        x = x.squeeze(0)
        x = self.raw_fc(x)
        #print('after raw fc',x.shape)

        x = x.unsqueeze(0)

        h0 = torch.randn(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
        c0 = torch.randn(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        x, _ = self.lstm(x, (h0, c0))
        #print('after lstm', x.shape)

        x = x.squeeze(0)
        out = self.fc(x)
        #print('final', out.shape)

        return out

class ANNPrivate(nn.Module):
    def __init__(self, input_dim, state_size, cnn_output_dim, hidden_dim, layer_dim, dropout, output_dim):
        super(ANNPrivate, self).__init__()
        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim

        self.raw_fc = nn.Sequential(nn.Linear(input_dim, 64), nn.ReLU())

        self.cnn = nn.Sequential(nn.Conv1d(state_size, cnn_output_dim, 3, stride=1, padding = 1), nn.ReLU())

        self.lstm = nn.LSTM(
            64, hidden_dim, layer_dim, batch_first=True, dropout= dropout
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 8),
            nn.ReLU(),
            nn.Linear(8, output_dim)
        )

    def forward(self, x):
        x = x.unsqueeze(0)
        #print(x.shape)

        x = self.raw_fc(x)
        #print(x.shape)
        x = self.cnn(x)

        h0 = torch.randn(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()
        c0 = torch.randn(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        x, _ = self.lstm(x, (h0, c0))
        x = x.squeeze(0)
        out = self.fc(x)

        return out

class Policy(nn.Module):
    """
    implements both actor and critic in one model
    """
    def __init__(self):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(state_size*2, 64)

        self.fc2 = nn.Linear(64, 32)
        self.layer_out = nn.Sequential(nn.Linear(32, 1), nn.Softmax(dim=-1))


        # actor's layer
        self.action_head = nn.Linear(state_size*2, 1)
        self.mu = nn.Sigmoid()
        self.var = nn.Softplus()

        # critic's layer
        self.value_head = nn.Linear(state_size*2, 1)


    def forward(self, x):
        """
        forward of both actor and critic
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.layer_out(x)
        x = torch.flatten(x)

        action_prob = self.action_head(x)
        mu = self.mu(action_prob)
        var = self.var(action_prob)

        state_values = self.value_head(x)

        return mu, var, state_values

And the model is:

class Model(torch.nn.Module):
    def __init__(self, ann_private, rnn_public, policy):
        super(Model, self).__init__()
        self.ann_private = ann_private
        self.rnn_public = rnn_public
        self.policy = policy

    def forward(self, private_input, state):
      private_input = self.ann_private(private_input)
      processed_state = self.rnn_public(state)
      self.inference_state = torch.cat([processed_state, private_input], dim=-1)

      mu, var, state_values =  self.policy(self.inference_state)
      return mu, var, state_values
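
For reference, here is a minimal sketch of how the three submodules and the composed model could be instantiated so that the shapes line up with the test tensors used below. The hyperparameter values are placeholders chosen only to make the example self-contained, not a real configuration:

state_size = 10   # Policy reads this module-level variable

ann_private = ANNPrivate(input_dim=2, state_size=state_size, cnn_output_dim=20,
                         hidden_dim=16, layer_dim=1, dropout=0.0, output_dim=10)
rnn_public = LSTMModelPublic(state_size=state_size, cnn_output_dim=20, input_dim=30,
                             hidden_dim=16, layer_dim=1, dropout=0.0, output_dim=10)
policy = Policy()
model = Model(ann_private, rnn_public, policy)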

I am testing the code with unittest to see if the grads are working properly. I am doing this:

from unittest import TestCase

class TestModel(TestCase):
    def setUp(self):
        self.private_input = torch.randn(10,2)
        self.state_input = torch.randn(1, 10, 30)

    def test_all_parameters_updated_for_specific_model(self):
        self.private_input.requires_grad = True
        self.state_input.requires_grad = True
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, eps=1e-8, weight_decay=1e-5, amsgrad=True)
        optimizer.zero_grad()
        output1, output2, output3 = model(self.private_input, self.state_input)
        loss = output1 + output2 + output3
        loss.backward()
        optimizer.step()
        
        for param_name, param in model.named_parameters():
            if param.requires_grad:
                with self.subTest(name=param_name):
                    self.assertIsNotNone(param.grad)
                    self.assertNotEqual(0., torch.sum(param.grad ** 2).item())

I am seeing that the last assertion fails, i.e. torch.sum(param.grad ** 2).item() is 0.0 for some parameters.
But the assertion before it, self.assertIsNotNone(param.grad), passes. I don't understand what is going on. Is it that backpropagation is not working properly? What could be the solution to this problem?
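
A quick sanity check (a sketch, separate from the test above, run after the backward pass) is to aggregate the squared gradients per top-level submodule to see where the gradient signal stops:

from collections import defaultdict

totals = defaultdict(float)
for name, param in model.named_parameters():
    if param.grad is not None:
        # group by the top-level submodule name: ann_private / rnn_public / policy
        totals[name.split(".")[0]] += param.grad.pow(2).sum().item()

for submodule, total in totals.items():
    print(submodule, total)  # a submodule with total 0.0 received no gradient at all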

Well, I printed torch.sum(param.grad ** 2).item() for all the parameters with print(param_name, torch.sum(param.grad ** 2).item()) and found that the sums were zero for every parameter up to the action_head of the Policy, and nonzero from the action_head onward. That suggests the problem is at self.layer_out = nn.Sequential(nn.Linear(32, 1), nn.Softmax(dim=-1)) in the Policy. I changed it to self.layer_out = nn.Sequential(nn.Linear(32, 1)) and the problem seems to be solved. But can anyone tell me why nn.Softmax(dim=-1) was blocking the gradients from propagating further back?
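
To illustrate, here is a minimal standalone snippet (made-up shapes, unrelated to the model above) where the same nn.Sequential(nn.Linear(32, 1), nn.Softmax(dim=-1)) pattern produces a constant output and zero gradients, since a softmax over a dimension of size 1 always returns exactly 1.0:

import torch
import torch.nn as nn

layer_out = nn.Sequential(nn.Linear(32, 1), nn.Softmax(dim=-1))
x = torch.randn(20, 32, requires_grad=True)

out = layer_out(x)
print(out.min().item(), out.max().item())            # both 1.0: softmax over a size-1 dim is constant

out.sum().backward()
print(layer_out[0].weight.grad.abs().sum().item())   # 0.0: no gradient reaches the Linear layer
print(x.grad.abs().sum().item())                     # 0.0: nothing reaches the inputs either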