Updating multiple networks simultaneously on CPU

I’m currently moving an implementation of a couple of popular reinforcement learning algorithms from TensorFlow to PyTorch, and the PyTorch code is noticeably slower (up to 50%). We believe this is because some algorithms, such as Soft Actor-Critic, have multiple semi-independent neural networks (e.g. policy, value, Q function) that, during the update step, need to be evaluated, have their losses computed, and have backpropagation performed on them. In TensorFlow these could all be computed with a single session.run call, and they would be parallelized across multiple cores. In PyTorch, however, we are limited to evaluating and updating each network sequentially.

Note that our networks are quite small, so we don’t expect much benefit from running on GPU. The asynchronous execution in CUDA does seem to close much of the performance gap between TensorFlow and PyTorch, but at the end of the day we still need to support CPU.

I was wondering what the “right” way to do this is for PyTorch on CPU. I’ve played around with using Python threading to evaluate the networks (PyTorch does release the GIL, right?) as well as changing the num_threads and KMP_BLOCKTIME settings, with some success, but I am still at the 50% performance gap. Guidance would be much appreciated, thanks!
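For reference, a minimal sketch of how those two settings can be adjusted. The values are illustrative assumptions, not recommendations, and KMP_BLOCKTIME only has an effect on builds that use the Intel OpenMP runtime:

```python
import os

# KMP_BLOCKTIME is read when the OpenMP runtime starts, so set it before
# importing torch. "0" is only an illustrative value, not a recommendation.
os.environ.setdefault("KMP_BLOCKTIME", "0")

import torch

# Intra-op parallelism: threads used inside a single operator (e.g. a matmul).
# For very small networks, fewer threads may reduce synchronization overhead.
torch.set_num_threads(2)  # illustrative value; tune by benchmarking

intra_op_threads = torch.get_num_threads()
print(f"intra-op threads: {intra_op_threads}")
```

The best values depend on the machine and the network sizes, so they are worth benchmarking rather than guessing.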

Show your implementation.

Here is the pseudo-code of the model update (soft actor-critic):

# Sample observations, dones, rewards from experience replay buffer
observations, next_observations, dones, rewards = sample_batch()

# Evaluate current policy on sampled observations
sampled_actions, log_probs = policy.sample_actions(observations)

# Evaluate Q networks on observations and actions
q1p_out, q2p_out = value_network(observations, sampled_actions)
q1_out, q2_out = value_network(observations, actions)

# Evaluate target network on next observations
with torch.no_grad():
    target_values = target_network(next_observations)

# Evaluate losses
q1_loss, q2_loss = sac_q_loss(q1_out, q2_out, target_values, dones, rewards)
value_loss = sac_value_loss(log_probs, sampled_values, q1p_out, q2p_out)
policy_loss = sac_policy_loss(log_probs, q1p_out)
entropy_loss = sac_entropy_loss(log_probs)

total_value_loss = q1_loss + q2_loss + value_loss

# Backprop and weights update



I’m looking for a way to parallelize the value_network, sample_actions, and target_network evaluations, as well as the three backward() calls, as these components consume the most time. I’ve tried running each in a different Python thread, but a combination of the GIL and context-switching overhead seems to negate any possible speedup.
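A minimal sketch of the threaded approach described above, not the actual implementation. The network shapes, the names in `networks`, and the `evaluate` helper are hypothetical stand-ins:

```python
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical small networks standing in for the value and Q networks.
networks = {
    name: nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
    for name in ("value", "q1", "q2")
}
observations = torch.randn(32, 8)


def evaluate(net: nn.Module) -> torch.Tensor:
    # Grad mode is thread-local, so no_grad is entered inside the worker.
    with torch.no_grad():
        return net(observations)


# One Python thread per network; each forward pass is submitted as a task.
with ThreadPoolExecutor(max_workers=len(networks)) as pool:
    futures = {name: pool.submit(evaluate, net) for name, net in networks.items()}
    outputs = {name: fut.result() for name, fut in futures.items()}

print({name: tuple(out.shape) for name, out in outputs.items()})
```

With networks this small, each operator finishes quickly, so the threads spend a large share of their time in Python-level dispatch, which is where the GIL comes in.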

If your model is small, you can consider using torch.jit._fork and torch.jit._wait. They can work around the GIL problem, but they also require you to convert your models to JitScriptModule.

While PyTorch does release the GIL, the GIL can still hurt performance if your model consists of many small linear layers, such as 100 (in) -> 100 (out).

The jit._fork seems promising, thanks! It wasn’t clear to me from the documentation: do fork and wait have to be called within a JitScriptModule, or can they be called on functions from JitScriptModules in regular Python code?

Compile a model to JitScriptModule like this (t is torch, imported with import torch as t):

jit_actor = t.jit.script(acc)

and call it like this:

future = [t.jit._fork(model, *args_list) for model, args_list in zip(models, args_lists)]

Then wait for the results (each call blocks until that task has finished):

result = [t.jit._wait(fut) for fut in future]

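These APIs can be called outside of your JIT module. Note, however, that according to the torch.jit.fork documentation, tasks only execute asynchronously when fork is invoked from TorchScript; when called from pure Python they do not run in parallel.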
