DQN cartpole agent from pytorch's tutorial not learning

I tried to implement the pytorch’s DQN tutorial https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html as an agent class for tidyness and in order to implement n-step learning. I wrote the following code rearranging the tutorial:

class Agent:
  def __init__(self, exp_buffer, policy_net, target_net, device):
    self.exp_buffer = exp_buffer
    self.policy_net = policy_net
    self.target_net = target_net
    self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=1e-4, amsgrad=True)
    self.device = device
    self.steps_done = 0
    self.eps_start = 0.9
    self.eps_end = 0.05
    self.eps_decay = 1000
    self.gamma = 0.99
    self.tau = 0.005


  def _eps_decay(self):
    return self.eps_end + (self.eps_start - self.eps_end) * \
        math.exp(-1. * self.steps_done / self.eps_decay)


  def _select_action(self, state):

    eps_threshold = self._eps_decay()

    if random.random() > eps_threshold:
      with torch.no_grad():
        # .view(1,1) in order for it to work with the .gather() method
        # used to extract state action values
        return self.policy_net(state).max(1)[1].view(1,1)
    else:
      return torch.tensor([[env.action_space.sample()]], device=self.device, dtype=torch.long)


  def sync_nets(self):
    target_net_state_dict = self.target_net.state_dict()
    policy_net_state_dict = self.policy_net.state_dict()

    for key in policy_net_state_dict:
      target_net_state_dict[key] = policy_net_state_dict[key]*self.tau + target_net_state_dict[key]*(1-self.tau)

    self.target_net.load_state_dict(target_net_state_dict)


  def explore(self, env, state, n_steps=1):
    
    state_init = state
    state_init = torch.tensor(state_init, dtype=torch.float32, device=self.device).unsqueeze(0)

    total_reward = 0.0
    
    for step in range(n_steps):
      state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
      action = self._select_action(state)
      observation, reward, terminated, truncated, _ = env.step(action.item())

      if terminated:
        next_state = None
      else:
        next_state = torch.tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)

      reward *= self.gamma**step

      total_reward += reward
      state = next_state

    done = terminated or truncated
    self.steps_done += 1
    total_reward = torch.tensor([total_reward], device=self.device)
    self.exp_buffer.push(state_init, action, next_state, total_reward)

    return done


  def optimize_model(self, batch_size, n_steps=1):

    if len(self.exp_buffer) < batch_size:
      return

    transitions = self.exp_buffer.sample(batch_size)
    batch = Transition(*zip(*transitions))
    
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=self.device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])


    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    state_action_values = self.policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]

    expected_state_action_values = (next_state_values * self.gamma**n_steps) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    self.optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)
    self.optimizer.step()

episode_durations = []
seed = 1423
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)
torch.manual_seed(seed)
policy_net = DQN(n_observations, n_actions).to(device)
target_net = copy.deepcopy(policy_net)
target_net.to(device)

exp_buffer = ReplayMemory(10000)

if torch.cuda.is_available:
    num_episodes = 600
else:
    num_episodes = 50

agent = Agent(exp_buffer, policy_net, target_net, device)

for i_episode in range(num_episodes):
  state, info = env.reset()

  for t in count():
    
    done = agent.explore(env, state)
    agent.optimize_model(batch_size=128)
    agent.sync_nets()

    if done:
      episode_durations.append(t + 1)
      plot_durations()
      break

print('Complete')
plot_durations(show_result=True)
plt.plot(episode_durations)
plt.show()

I omitted imports and the definitions of the DQN, buffer and plot function for brevity as they are the same as in the tutorial.

In the .explore() method is implemented the n step dqn but in this case is used as simple 1-step and it should be the same as in the tutorial.

When i run the code the reward is stuck around 10-15 and it doesn’t grow (when obviously running the code as in the tutorial works well)