DQN cartpole agent from pytorch's tutorial not learning

I tried to implement the pytorch’s DQN tutorial https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html as an agent class for tidyness and in order to implement n-step learning. I wrote the following code rearranging the tutorial:

class Agent:
  def __init__(self, exp_buffer, policy_net, target_net, device):
    self.exp_buffer = exp_buffer
    self.policy_net = policy_net
    self.target_net = target_net
    self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=1e-4, amsgrad=True)
    self.device = device
    self.steps_done = 0
    self.eps_start = 0.9
    self.eps_end = 0.05
    self.eps_decay = 1000
    self.gamma = 0.99
    self.tau = 0.005

  def _eps_decay(self):
    return self.eps_end + (self.eps_start - self.eps_end) * \
        math.exp(-1. * self.steps_done / self.eps_decay)

  def _select_action(self, state):

    eps_threshold = self._eps_decay()

    if random.random() > eps_threshold:
      with torch.no_grad():
        # .view(1,1) in order for it to work with the .gather() method
        # used to extract state action values
        return self.policy_net(state).max(1)[1].view(1,1)
      return torch.tensor([[env.action_space.sample()]], device=self.device, dtype=torch.long)

  def sync_nets(self):
    target_net_state_dict = self.target_net.state_dict()
    policy_net_state_dict = self.policy_net.state_dict()

    for key in policy_net_state_dict:
      target_net_state_dict[key] = policy_net_state_dict[key]*self.tau + target_net_state_dict[key]*(1-self.tau)


  def explore(self, env, state, n_steps=1):
    state_init = state
    state_init = torch.tensor(state_init, dtype=torch.float32, device=self.device).unsqueeze(0)

    total_reward = 0.0
    for step in range(n_steps):
      state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)
      action = self._select_action(state)
      observation, reward, terminated, truncated, _ = env.step(action.item())

      if terminated:
        next_state = None
        next_state = torch.tensor(observation, dtype=torch.float32, device=self.device).unsqueeze(0)

      reward *= self.gamma**step

      total_reward += reward
      state = next_state

    done = terminated or truncated
    self.steps_done += 1
    total_reward = torch.tensor([total_reward], device=self.device)
    self.exp_buffer.push(state_init, action, next_state, total_reward)

    return done

  def optimize_model(self, batch_size, n_steps=1):

    if len(self.exp_buffer) < batch_size:

    transitions = self.exp_buffer.sample(batch_size)
    batch = Transition(*zip(*transitions))
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                          batch.next_state)), device=self.device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                                if s is not None])

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    state_action_values = self.policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0]

    expected_state_action_values = (next_state_values * self.gamma**n_steps) + reward_batch

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    torch.nn.utils.clip_grad_value_(self.policy_net.parameters(), 100)

episode_durations = []
seed = 1423
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)
policy_net = DQN(n_observations, n_actions).to(device)
target_net = copy.deepcopy(policy_net)

exp_buffer = ReplayMemory(10000)

if torch.cuda.is_available:
    num_episodes = 600
    num_episodes = 50

agent = Agent(exp_buffer, policy_net, target_net, device)

for i_episode in range(num_episodes):
  state, info = env.reset()

  for t in count():
    done = agent.explore(env, state)

    if done:
      episode_durations.append(t + 1)


I omitted imports and the definitions of the DQN, buffer and plot function for brevity as they are the same as in the tutorial.

In the .explore() method is implemented the n step dqn but in this case is used as simple 1-step and it should be the same as in the tutorial.

When i run the code the reward is stuck around 10-15 and it doesn’t grow (when obviously running the code as in the tutorial works well)