How to properly create a batch with torch.Tensor

Hello!

I have a repo where I have implemented A2C and PPO and it uses the same code to gather a batch of data:

    states = np.zeros((args.num_steps, args.num_envs) + observation_shape)
    actions = np.zeros((args.num_steps, args.num_envs) + action_shape)
    rewards = np.zeros((args.num_steps, args.num_envs))
    flags = np.zeros((args.num_steps, args.num_envs))
    state_values = torch.zeros((args.num_steps, args.num_envs)).to(args.device)
    log_probs = torch.zeros((args.num_steps, args.num_envs)).to(args.device)

    num_updates = int(args.total_timesteps // args.num_steps)
    global_step = 0

    state, _ = envs.reset(seed=SEED)
    next_done = np.zeros(args.num_envs)

    for _ in tqdm(range(num_updates)):
        # log_probs, state_values = [], []
        start = time.perf_counter()

        for i in range(args.num_steps):
            global_step += 1
            flags[i] = next_done

            action, log_prob, state_value = agent.get_action(state)

            next_state, reward, terminated, truncated, infos = envs.step(
                action)

            states[i] = state
            actions[i] = action
            rewards[i] = reward
            log_probs[i] = log_prob
            state_values[i] = state_value
            # state_values.append(state_value)
            # log_probs.append(log_prob)

            state = next_state
            next_done = np.logical_or(terminated, truncated)

            if "final_info" not in infos:
                continue

            for info in infos["final_info"]:
                if info is None:
                    continue

                writer.add_scalar("rollout/episodic_return",
                                  info["episode"]["r"], global_step)
                writer.add_scalar("rollout/episodic_length",
                                  info["episode"]["l"], global_step)

        agent.update_policy(
            {
                "states": torch.from_numpy(states).float().to(args.device),
                "actions": torch.from_numpy(actions).float().to(args.device),
                "last_state": torch.from_numpy(next_state).float().to(
                    args.device),
                "last_flag": torch.from_numpy(next_done).float().to(
                    args.device),
                "rewards": torch.from_numpy(rewards).float().to(args.device),
                "flags": torch.from_numpy(flags).float().to(args.device),
                # "state_values": torch.stack(state_values).squeeze(),
                # "log_probs": torch.stack(log_probs).squeeze()
                "state_values": state_values,
                "log_probs": log_probs
            }, global_step)

The code runs fine with PPO, but with A2C I get this error message:

    Traceback (most recent call last):
      File "/var/home/valentin/Work/rl-gym/main.py", line 225, in <module>
        main()
      File "/var/home/valentin/Work/rl-gym/main.py", line 201, in main
        agent.update_policy(
      File "/var/home/valentin/Work/rl-gym/rl_gym/agent.py", line 54, in update_policy
        self.algorithm.update_policy(batch, step)
      File "/var/home/valentin/Work/rl-gym/rl_gym/algorithm/a2c.py", line 93, in update_policy
        loss.backward()
      File "/var/home/valentin/Work/rl-gym/.venv/lib64/python3.10/site-packages/torch/_tensor.py", line 488, in backward
        torch.autograd.backward(
      File "/var/home/valentin/Work/rl-gym/.venv/lib64/python3.10/site-packages/torch/autograd/__init__.py", line 197, in backward
        Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I identified where it comes from: it is the way I'm saving the log_probs during the rollout. When I use the commented-out version instead (storing them in a list), everything works perfectly.
The same happens with the critic values (stored in state_values): when I write them into the pre-allocated tensor I get a segmentation fault, while storing them in a list works fine.

In the end I fixed the bug by storing log_probs and state_values in lists, but I don't understand why I have to do this. It was difficult to debug, and I still don't really get why it works one way and not the other (writing into pre-allocated tensors, which should be more efficient).
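
For reference, here is a minimal standalone snippet (a toy network and a reused buffer, not my actual code) that reproduces the same error:

    import torch

    net = torch.nn.Linear(3, 1)        # toy stand-in for the actor/critic network
    values = torch.zeros(4)            # pre-allocated buffer, reused across updates

    for update in range(2):
        for i in range(4):
            # The network output still requires grad, so the in-place write
            # drags `values` (and the previous update's graph) into the new graph.
            values[i] = net(torch.randn(3)).squeeze()

        loss = values.sum()
        loss.backward()                # second update: "Trying to backward through the graph a second time"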

Hey @valentin-cnt
I think you should collect data under torch.no_grad and compute your loss independently. For PPO and A2C, the loss should not backpropagate gradients through the log-probabilities collected during inference.
Here are a couple of implementations in torchrl to give you a sense of how it's done:

This is a WIP tutorial on PPO but the code works fine

Here's the PPO loss in torchrl. It's executed on data that do not carry any gradients, so to speak (more precisely, they're not part of any differentiable computational graph).
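
To make the idea concrete, here is a toy sketch (an assumed value network, nothing torchrl-specific): the rollout is collected under torch.no_grad, so the stored batch is plain data, and the loss is built from a fresh forward pass at update time:

    import torch

    value_net = torch.nn.Linear(3, 1)                 # toy critic
    optimizer = torch.optim.SGD(value_net.parameters(), lr=1e-2)
    values = torch.zeros(4)                           # pre-allocated buffer, reused every update

    for update in range(2):
        observations = torch.randn(4, 3)

        # Collection: no graph is recorded, `values` stays plain data
        # (it can still be used gradient-free, e.g. for advantage estimation).
        with torch.no_grad():
            for i in range(4):
                values[i] = value_net(observations[i]).squeeze()

        # Update: the graph is built here, from a fresh forward pass.
        targets = torch.randn(4)                      # placeholder returns
        loss = (value_net(observations).squeeze(-1) - targets).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()                               # fine on every update: only one graph exists
        optimizer.step()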

Hope that helps!

Hey @vmoens
Thank you for your answer. I actually do use torch.no_grad for PPO, but when I use it for A2C, it doesn't learn anything at all. I guess my implementation is wrong, but since it worked the way I coded it, I don't know what's wrong.

Here’s the code:

I think you need to compute the log prob in your loss, not in the data batch.
Check torchrl’s A2C code

Check the way we get the log-prob: it’s done at learning time, not during inference.
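
As a rough sketch of that pattern (a toy example with assumed names, not the actual torchrl code): the batch only holds states, actions and advantages, and the log-probabilities are recomputed from the current policy inside the loss:

    import torch
    from torch.distributions import Categorical

    policy = torch.nn.Linear(4, 2)                    # toy actor: 4 obs dims, 2 discrete actions
    optimizer = torch.optim.Adam(policy.parameters(), lr=7e-4)

    # Batch collected earlier under torch.no_grad: plain data, no graph attached.
    states = torch.randn(16, 4)
    actions = torch.randint(0, 2, (16,))
    advantages = torch.randn(16)                      # placeholder advantages

    # Learning time: recompute the log-probabilities with a fresh forward pass,
    # so the policy gradient flows through the current parameters.
    dist = Categorical(logits=policy(states))
    log_probs = dist.log_prob(actions)
    loss = -(log_probs * advantages).mean() - 0.01 * dist.entropy().mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()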

Hope that helps

Yes, it works. Thank you for your answer; it actually makes sense to keep the data batch and the update data separate. I understand it better now, thank you very much!
