Multi-Threaded Backprop Failing in A3C Implementation

Greetings,

I am attempting to implement the A3C algorithm, which requires a pool of agents, each following a relatively distinct trajectory through the environment. I am doing this with the torch.multiprocessing module.
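For context, here is a minimal sketch of the usual shared-model pattern with torch.multiprocessing. A global model keeps its parameters in shared memory. Each worker process computes gradients on a local copy, hands them to the shared parameters, and steps an optimizer. A shared counter is guarded by its lock, which is the same role `global_idx` plays in the worker below. The toy `nn.Linear` model, the placeholder loss, and the `fork` start method are all assumptions for illustration, so this sketch assumes Linux or another POSIX system. It is not the original project's code.

```python
import torch
import torch.nn as nn
import torch.multiprocessing as mp


def worker(rank, global_model, counter, n_updates):
    """Hypothetical worker: placeholder loss, gradients written into the shared model."""
    torch.manual_seed(rank)
    local_model = nn.Linear(4, 1)
    # Plain SGD keeps no per-process state, so each worker can own one.
    optimizer = torch.optim.SGD(global_model.parameters(), lr=0.01)
    for _ in range(n_updates):
        local_model.load_state_dict(global_model.state_dict())  # sync with shared weights
        loss = local_model(torch.randn(8, 4)).pow(2).mean()       # placeholder loss
        local_model.zero_grad()
        loss.backward()
        for lp, gp in zip(local_model.parameters(), global_model.parameters()):
            gp.grad = lp.grad      # hand the local gradient to the shared parameter
        optimizer.step()           # updates the shared-memory weights in place
        with counter.get_lock():   # same idea as global_idx in the worker
            counter.value += 1


def main(n_workers=2, n_updates=5):
    ctx = mp.get_context("fork")   # assumption: POSIX system with fork available
    global_model = nn.Linear(4, 1)
    global_model.share_memory()    # parameters live in shared memory
    counter = ctx.Value("i", 0)
    procs = [ctx.Process(target=worker, args=(r, global_model, counter, n_updates))
             for r in range(n_workers)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return counter.value, [p.exitcode for p in procs]


if __name__ == "__main__":
    print(main())
```

With the `spawn` start method (the default on Windows and macOS), the worker function must be importable and everything that starts processes must stay under the `__main__` guard.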

I am receiving a relatively cryptic error (“No grad_fn for non-leaf saved tensor”) that asks me to report a bug to PyTorch.

The error seems to originate in the backward pass of the loss computed in each worker process. Below is the function that each process runs (the function being “multiplexed”, if you will). See the backprop error at the “HERE” comment:

```python
import numpy as np
import torch as T

# ActorCriticNetwork, Memory, make_environment and plot_learning_curve
# are project-specific helpers defined elsewhere.


def worker(name, input_shape, n_actions, global_agent, optimizer, env_id,
           n_threads, global_idx):
    T_max = 20
    local_agent = ActorCriticNetwork(input_shape, n_actions)
    memory = Memory()

    # input_shape has channels first, but the frame buffer must have
    # channels last, so swap them:
    frame_buffer = [input_shape[1], input_shape[2], 1]
    env = make_environment(env_id, shape=frame_buffer)

    episode, max_eps, t_steps, scores = 0, 1000, 0, []

    while episode < max_eps:
        obs = env.reset()
        score, done, ep_steps = 0, False, 0
        hidden_state = T.zeros(1, 256)
        while not done:
            obs = T.tensor([obs], dtype=T.float)
            # obs = T.tensor(np.array(obs), dtype=T.float)  # CURRENT

            action, value, log_prob, hidden_state = local_agent(obs, hidden_state)
            next_obs, reward, done, info = env.step(action)
            memory.store_transition(reward, value, log_prob)
            score += reward
            ep_steps += 1
            t_steps += 1
            obs = next_obs

            if ep_steps % T_max == 0 or done:
                rewards, values, log_probs = memory.sample_memory()
                loss = local_agent.calc_cost(
                    obs, hidden_state, done, rewards, values, log_probs
                )
                optimizer.zero_grad()
                hidden_state = hidden_state.detach_()
                # hidden_state = hidden_state.detach()  # CURRENT

                loss.backward()  ### ------------------------ HERE ------------------------ ###

                # loss.sum().backward()  # CURRENT
                T.nn.utils.clip_grad_norm_(local_agent.parameters(), 40)

                # hand the local gradients to the shared global model
                for local_param, global_param in zip(
                        local_agent.parameters(),
                        global_agent.parameters()):
                    global_param._grad = local_param.grad

                optimizer.step()
                local_agent.load_state_dict(global_agent.state_dict())
                memory.clear_memory()

        episode += 1

        with global_idx.get_lock():
            global_idx.value += 1

        if name == '1':
            scores.append(score)
            avg_score = np.mean(scores[-100:])
            print(f'A3C episode: {episode}, thread: {name} of {n_threads}, '
                  f'steps: {t_steps/1e6}, score: {score}, avg_score: {avg_score}')

    if name == '1':
        x = list(range(episode))
        plot_learning_curve(x, scores, 'A3C_pong_final.png')
```
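One thing worth checking in the loop above: `hidden_state.detach_()` is called before `loss.backward()`, while that same `hidden_state` has already been fed into the network (and into `calc_cost`) and is therefore saved inside the graph. The minimal sketch below uses a hypothetical two-step tanh recurrence in place of `ActorCriticNetwork`. It suggests that detaching a saved, non-leaf tensor in place makes the later backward fail. Calling `backward()` first and then detaching out-of-place, as in the commented-out `hidden_state.detach()` line, works. The exact error text depends on the PyTorch version.

```python
import torch

torch.manual_seed(0)
W = torch.randn(4, 4, requires_grad=True)  # hypothetical recurrent weight


def rnn_step(h: torch.Tensor) -> torch.Tensor:
    # The matmul saves `h` for backward, because it is needed to compute W.grad.
    return torch.tanh(h @ W)


def try_update(detach_in_place_first: bool) -> bool:
    """Run one truncated-BPTT style update; return True if backward() succeeded."""
    W.grad = None
    hidden = rnn_step(torch.ones(1, 4))  # non-leaf hidden state, like the worker's
    loss = rnn_step(hidden).sum()        # `hidden` is now saved inside this graph
    try:
        if detach_in_place_first:
            hidden.detach_()             # worker's order: detach_ before backward
            loss.backward()
        else:
            loss.backward()              # backward first ...
            hidden = hidden.detach()     # ... then cut the graph out-of-place
        return True
    except RuntimeError as err:
        print("backward failed:", err)
        return False


print(try_update(detach_in_place_first=True))
print(try_update(detach_in_place_first=False))
```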

I’ve scoured around for a while and found nothing particularly helpful. I’m happy to elaborate and to provide additional material/context. Any assistance will be greatly appreciated.

Many thanks,

Merlin :wink:

Could you check if you are seeing the same error in the latest nightly binary?
If so, would it be possible for you to post a minimal, executable code snippet to reproduce the issue?

Thanks for the reply.

I assume I just install the latest nightly PyTorch version into a virtual environment and try again?
I’ll try this anyway.

Thanks again.

Hmmm,

So I tried this by running the latest nightly CPU binary of PyTorch in a virtual environment. Here is the exact version I downloaded:

I appear to have got the exact same error as with my base install:

I’m not sure how to post a “minimal” code snippet for reproducibility, since there are several dependencies in this project. Perhaps I could make a GitHub repo for you to examine?

Many thanks!

Merlin

Actually, it seems I have managed to resolve the issue.

I don’t know why I didn’t try this sooner, since the error said: “No grad_fn for non-leaf saved tensor”. I just went into the function that computes the loss and added “requires_grad=True” to the actual loss tensor before calling loss.backward(). This seems to have fixed my issue; see below for the fix, at the “HERE” comment:

Many thanks to anyone who took the time to read my original query.
I welcome any suggestions for improvements.

Thanks,

Merlin

Your fix avoids the error, but it detaches the computation graph, so your model would most likely not learn anything, assuming total_loss is used to compute the gradients.

Yeah, this is exactly what happened.

total_loss is used in precisely that way, and the model learned nothing.
Why would specifying that the gradient is required detach it in any way? Is there a simple workaround from here?

Thanks again.

I think your log_probs are sampled from the replay buffer, aren’t they? They should be differentiable, which is almost surely not the case if they come from the replay buffer (RB). Have you considered re-computing them by passing the stored data through the policy network when calculating the loss? I think this may be why the backward pass does not produce gradients for the policy weights. The critic weights should be OK, since you call self.forward (unless done is True; I would make sure that done is mostly False in your task).
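A minimal sketch of what re-computing at loss time could look like follows. It assumes a feed-forward actor-critic with no hidden state, and it uses placeholder returns. The rollout stores only plain observations and actions. The update runs a fresh forward pass, so the log-probs and values are connected to the current weights. `TinyActorCritic` and the loss weighting are hypothetical, not the original `calc_cost`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class TinyActorCritic(nn.Module):
    """Hypothetical stand-in for ActorCriticNetwork (no recurrent state)."""

    def __init__(self, obs_dim: int, n_actions: int):
        super().__init__()
        self.body = nn.Linear(obs_dim, 32)
        self.pi = nn.Linear(32, n_actions)
        self.v = nn.Linear(32, 1)

    def forward(self, obs):
        h = F.relu(self.body(obs))
        return self.pi(h), self.v(h).squeeze(-1)


torch.manual_seed(0)
net = TinyActorCritic(4, 2)

# Rollout: act without building a graph and store only plain data.
stored_obs, stored_actions = [], []
with torch.no_grad():
    for _ in range(5):
        obs = torch.randn(4)
        logits, _ = net(obs)
        stored_obs.append(obs)
        stored_actions.append(Categorical(logits=logits).sample())
returns = torch.randn(5)  # placeholder discounted returns

# Update: recompute log-probs and values with a fresh forward pass.
obs_batch = torch.stack(stored_obs)
action_batch = torch.stack(stored_actions)
logits, values = net(obs_batch)
log_probs = Categorical(logits=logits).log_prob(action_batch)
advantages = returns - values
actor_loss = -(log_probs * advantages.detach()).mean()
critic_loss = advantages.pow(2).mean()
(actor_loss + critic_loss).backward()
print(all(p.grad is not None for p in net.parameters()))  # True
```

In a strictly on-policy A3C rollout, keeping the graph-connected log-probs from the rollout is also valid, as long as nothing cuts the graph before `backward()`.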

If you’re interested I think torchrl provides most of the primitives you’re looking for here, we’d be glad to provide support if you’d like to use them.

You are creating a new tensor, which detaches it from the current computation graph and creates a new leaf tensor. The right approach is to use the output tensors directly without re-wrapping them.
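To make this concrete, here is a minimal sketch with a toy `nn.Linear` (a hypothetical stand-in, not the original model) and a placeholder loss. Wrapping the loss value in a new tensor with `requires_grad=True` produces a leaf that backward can run from, but no gradient reaches the model. Backpropagating from the original loss does reach it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)        # hypothetical stand-in for the actor-critic
x = torch.randn(8, 3)

loss = model(x).pow(2).mean()  # placeholder loss that keeps its grad_fn

# Re-wrapping the value builds a brand-new leaf tensor with no history.
rewrapped = torch.tensor(loss.item(), requires_grad=True)
rewrapped.backward()
grad_after_rewrapped = model.weight.grad
print(rewrapped.is_leaf, grad_after_rewrapped)  # True None

# Backpropagating from the original tensor reaches the model's weights.
loss.backward()
print(model.weight.grad is not None)  # True
```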

@vmoens’s suggestion sounds valid and torchrl might be a good starter for your use case.

Thanks vmoens,

Yes, my log_probs are sampled from a replay buffer; they consist of a Python list of PyTorch tensors. When it comes time to compute the loss, I concatenate this list into a single tensor with torch.cat:

log_probs = T.cat(log_probs)

Inspecting the gradient function attached to the resulting tensor gives:

log_probs.grad_fn: <CatBackward0 object at 0x7f1a49496880>

So at least there is a gradient function for backward to use when computing the gradients. Assuming that this function is differentiable, I can’t see what the problem with backward might be.
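A `grad_fn` on the concatenated tensor only shows that it is attached to some graph. It does not show that the graph leads back to the parameters being optimised. A more direct check is to run `backward()` and see which parameters actually received a gradient. In this sketch, the two toy `nn.Linear` policies are hypothetical. The stored log-probs were produced by an older copy of the policy, so the concatenation still has a `CatBackward0` node while the current policy gets no gradient at all.

```python
import copy

import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)
policy = nn.Linear(4, 2)       # hypothetical actor being optimised
stale = copy.deepcopy(policy)  # hypothetical older copy that produced the stored data

log_probs = []
for _ in range(3):
    dist = Categorical(logits=stale(torch.randn(1, 4)))
    log_probs.append(dist.log_prob(dist.sample()))

stacked = torch.cat(log_probs)
print(stacked.grad_fn)         # a CatBackward0 node in recent versions

(-stacked.sum()).backward()


def params_without_grad(module: nn.Module) -> list:
    """Names of parameters that received no gradient from backward()."""
    return [name for name, p in module.named_parameters() if p.grad is None]


print(params_without_grad(policy))  # ['weight', 'bias']
print(params_without_grad(stale))   # []
```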

I soon realised that there was no need to wrap total_loss in another tensor with requires_grad=True, since a loss computed from the model’s outputs already requires grad by default (whoops). Thus, I think, I won’t run into the detaching issue that @ptrblck mentioned.

Here is what my loss function looks like now:

Thanks for the suggestion to check out torchrl. I have not done so yet (this project has been on the back burner for the last week), but I will endeavour to have a look.

I’m not sure exactly what you mean by recalculating the log_probs when it comes time to use them in the loss calculation. I can’t yet see how this doesn’t defeat the purpose of a replay buffer. Can you elaborate on this?

Many thanks to anyone reading this.

Usually log_probs is not sampled from the replay buffer, unless it’s used to correct for the change in model configuration (see PPO and related methods).
I suspect that if you sample data from the replay buffer, you are in an off-policy setting. As such, the data was not generated by the current policy, and it wouldn’t make sense to backpropagate through a graph that is now obsolete.
Hence you should sample from the replay buffer whatever is needed to compute the log prob, recompute the log prob from that data, and backpropagate through a loss built on this freshly computed log_prob.

Of course, this is all based on common assumptions about your model and replay buffer. If you’re storing data that is part of a graph in the replay buffer, you’re probably not using it right (replay buffers are usually like datasets: they hold fixed, graph-free data that is used purely to train a model in an off-policy manner).
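To illustrate the PPO-style case mentioned above, here is a minimal sketch that assumes a toy linear policy and placeholder advantages. The buffer holds graph-free data: observations, actions, and the old log-probs recorded under `torch.no_grad()`. The loss uses log-probs recomputed with the current policy. The stored values only enter through the clipped importance ratio. The clip value of 0.2 is just a common choice.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)
policy = nn.Linear(4, 2)  # hypothetical actor

# Hypothetical replay-buffer batch: plain, graph-free tensors.
with torch.no_grad():
    obs = torch.randn(16, 4)
    dist = Categorical(logits=policy(obs))
    actions = dist.sample()
    old_log_probs = dist.log_prob(actions)  # stored only for the correction term
advantages = torch.randn(16)                # placeholder advantages

# At update time, recompute log-probs with the current policy.
new_log_probs = Categorical(logits=policy(obs)).log_prob(actions)
ratio = torch.exp(new_log_probs - old_log_probs)
clip_eps = 0.2
surrogate = torch.min(ratio * advantages,
                      torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages)
loss = -surrogate.mean()
loss.backward()  # gradients flow only through new_log_probs
print(policy.weight.grad is not None)  # True
```

On the first update after collection the ratio is exactly 1. It only drifts from 1 as the policy is updated over several passes on the same batch, which is what the correction term is for.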

Let me know if I got anything wrong. Still happy to help if I do!