Differentiable programming for RL

Hello! I am trying to implement something similar to this (the CartPole example) using PyTorch. Basically I want to build a NN that predicts an action given the state of the system, compute the next state given that action, and then backpropagate through everything. They have an implementation in Julia here, and here is the beginning of my code in PyTorch:

import gym
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

env = gym.make('CartPole-v0')
state_len = env.observation_space.shape[0]  # 4 state variables for CartPole

class NN_cart(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_len, 24),
            nn.ReLU(inplace=True),
            nn.Linear(24, 48),
            nn.ReLU(inplace=True),
            nn.Linear(48, 1),
            nn.Tanh(),
        )
            
    def forward(self, x):
        x = self.net(x)
        # threshold the tanh output to a hard 0/1 action
        x = (torch.sign(x) + 1) / 2
        return x

model = NN_cart().cuda()

optimizer = optim.Adam(model.parameters(), lr = 1e-3)

done = False
env.reset()
state = torch.from_numpy(env.state).float().cuda()
while not done:
    model.train()
    optimizer.zero_grad()
    
    action = int(model(state)[0].cpu().data.numpy())
    state, reward, done, info = env.step(action)
    state = torch.from_numpy(state).float().cuda()
    state.requires_grad = True
    
    loss = state[2]**2  # penalize the squared pole angle
    
    loss.backward()
    optimizer.step()

I get no error, but the NN doesn’t learn. One issue is the gradient of the sign function: in the Julia implementation they define their own gradient, but I am not sure how to do that in PyTorch. Also, is my code, the way it is right now, able to backpropagate through everything (including the env.step() call)? Thank you!

If I’m not mistaken, the gradient of torch.sign is zero almost everywhere (and undefined at zero), so no useful learning signal can flow through it.
Since the authors implemented a custom backward method, you could have a look at this tutorial to see how to do the same in PyTorch with a custom torch.autograd.Function.
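
For example, a straight-through estimator is one common choice: keep the hard sign in the forward pass, but pass the incoming gradient through unchanged in the backward pass. This is only a sketch of that idea (the Julia code may define a different surrogate gradient), with SignSTE / sign_ste being names I made up:

import torch

class SignSTE(torch.autograd.Function):
    # sign() in the forward pass, identity gradient in the backward pass
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        # pretend d sign(x)/dx = 1, so a learning signal reaches the layers below
        return grad_output

def sign_ste(x):
    return SignSTE.apply(x)

In your forward method you would then use x = (sign_ste(x) + 1) / 2 instead of the plain torch.sign.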

I’m not exactly sure what env.step() is doing internally, but generally, once you leave PyTorch and use e.g. numpy (as gym does, and as your int(...)/.numpy() conversion of the action does), those operations are detached from the computation graph, which seems to be the case here. Your loss is then computed on a brand-new leaf tensor, so loss.backward() never reaches the model parameters and the optimizer has nothing meaningful to update.
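
To actually backpropagate through the dynamics (as the Julia example does), one option is to re-implement the CartPole transition in PyTorch instead of calling env.step(). Below is only a rough sketch, using the same Euler-integrated equations and constants as gym’s CartPoleEnv; cartpole_step and the continuous action mapping are my own naming, not part of gym:

import torch

def cartpole_step(state, action,
                  gravity=9.8, masscart=1.0, masspole=0.1,
                  length=0.5, force_mag=10.0, tau=0.02):
    # state = [x, x_dot, theta, theta_dot]; action is a tensor in [0, 1]
    x, x_dot, theta, theta_dot = state
    total_mass = masscart + masspole
    polemass_length = masspole * length

    # map the continuous action in [0, 1] to a force in [-force_mag, force_mag]
    action = action.reshape(())  # make sure it is a scalar tensor
    force = force_mag * (2.0 * action - 1.0)

    costheta, sintheta = torch.cos(theta), torch.sin(theta)
    temp = (force + polemass_length * theta_dot ** 2 * sintheta) / total_mass
    thetaacc = (gravity * sintheta - costheta * temp) / (
        length * (4.0 / 3.0 - masspole * costheta ** 2 / total_mass))
    xacc = temp - polemass_length * thetaacc * costheta / total_mass

    # explicit Euler integration, as in gym's CartPoleEnv
    x = x + tau * x_dot
    x_dot = x_dot + tau * xacc
    theta = theta + tau * theta_dot
    theta_dot = theta_dot + tau * thetaacc
    return torch.stack([x, x_dot, theta, theta_dot])

With something like this, the whole rollout stays inside autograd: you can unroll the episode, accumulate e.g. theta**2 as the loss, and call backward() once, so neither the int() conversion nor the numpy code inside gym breaks the graph anymore.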