Tanhnormal + affine transformation giving NaNs

I am training a reinforcement learning task with PPO and have parameterized the policy to output a normal distribution, which is then passed through a tanh transform and an affine transform. Sampling from the transformed distribution is supposed to give me rotation angles between -3.14 and 3.14 (hence the tanh transform to constrain the Gaussian samples to (-1, 1), followed by the affine transform to scale them by pi).

            # Split the policy head output into per-action means and log-std-devs
            gaussian_means, gaussian_logstddevs = logits.chunk(2, dim=-1)
            # tanh squashes samples to (-1, 1); the affine transform rescales them to (-pi, pi)
            transforms = [torch.distributions.TanhTransform(cache_size=1),
                          torch.distributions.AffineTransform(loc=0, scale=math.pi, cache_size=1)]
            base_gaussian = D.Normal(loc=gaussian_means.squeeze(dim=-1),
                                     scale=torch.exp(gaussian_logstddevs.squeeze(dim=-1)))
            return torch.distributions.TransformedDistribution(base_gaussian, transforms)
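
For reference, here is a minimal standalone sketch of how this distribution gets used during rollouts (the shapes and variable names here are just illustrative, not my actual network):

    import math
    import torch
    import torch.distributions as D

    # Dummy policy output: batch of 4, two values per action (mean, log-std)
    logits = torch.randn(4, 2)
    means, log_stds = logits.chunk(2, dim=-1)

    base = D.Normal(loc=means.squeeze(-1), scale=torch.exp(log_stds.squeeze(-1)))
    transforms = [D.TanhTransform(cache_size=1),
                  D.AffineTransform(loc=0, scale=math.pi, cache_size=1)]
    dist = D.TransformedDistribution(base, transforms)

    angles = dist.sample()             # values strictly inside (-pi, pi)
    log_probs = dist.log_prob(angles)  # used for the PPO ratio
    print(angles, log_probs)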

However, I am getting strange behavior. Here is the output of a run:

cuda:1
ln118 Using action angle [3.0207381]
ln118 Using action angle [-2.0774028]
ln118 Using action angle [0.5494586]
ln118 Using action angle [0.35132977]
ln118 Using action angle [2.7118604]
ln118 Using action angle [2.5479658]
ln118 Using action angle [2.314454]
ln118 Using action angle [1.4332277]
ln118 Using action angle [2.564674]
ln118 Using action angle [0.07476136]
ln118 Using action angle [3.045766]
ln118 Using action angle [2.066111]
ln118 Using action angle [2.1533582]
ln118 Using action angle [0.4264732]
ln118 Using action angle [-0.37309715]
ln118 Using action angle [2.545832]
ln118 Using action angle [2.7551084]
ln118 Using action angle [2.5142677]
ln118 Using action angle [1.9357533]
ln118 Using action angle [2.443508]
Early stopping at step 1 due to reaching max kl.
Early stopping at step 1 due to reaching max kl.
wandb name: stoic-cherry-76
------------------------------------------
|                Epoch |               0 |
| AverageEpFinalReward |            15.3 |
|     StdEpFinalReward |            12.7 |
|     MaxEpFinalReward |            37.2 |
|     MinEpFinalReward |            5.53 |
|                EpLen |               5 |
|         AverageVVals |           0.631 |
|             StdVVals |           0.179 |
|             MaxVVals |           0.826 |
|             MinVVals |           0.263 |
|    TotalEnvInteracts |              20 |
|           DeltaLossV |       -2.74e+03 |
|             StopIter |               1 |
|                 Time |             8.2 |
------------------------------------------
ln118 Using action angle [-3.141592]
ln118 Using action angle [-3.1415927]
ln118 Using action angle [-3.1411002]
ln118 Using action angle [-3.1415]
ln118 Using action angle [-3.1415927]
ln118 Using action angle [3.1000373]
ln118 Using action angle [-3.1415732]
ln118 Using action angle [-3.1415896]
ln118 Using action angle [-3.1411572]
ln118 Using action angle [3.1383228]
ln118 Using action angle [-2.7440572]
ln118 Using action angle [-3.141571]
ln118 Using action angle [-3.1415923]
ln118 Using action angle [-3.1415927]
ln118 Using action angle [-3.1415927]
ln118 Using action angle [-3.14097]
ln118 Using action angle [-3.1415899]
ln118 Using action angle [-3.141566]
ln118 Using action angle [-3.1415927]
ln118 Using action angle [-3.1415927]
Early stopping at step 3 due to reaching max kl.
wandb name: stoic-cherry-76
------------------------------------------
|                Epoch |               1 |
| AverageEpFinalReward |            37.2 |
|     StdEpFinalReward |        0.000489 |
|     MaxEpFinalReward |            37.2 |
|     MinEpFinalReward |            37.2 |
|                EpLen |               5 |
|         AverageVVals |              53 |
|             StdVVals |        7.03e-06 |
|             MaxVVals |              53 |
|             MinVVals |              53 |
|    TotalEnvInteracts |              40 |
|           DeltaLossV |       -4.38e+03 |
|             StopIter |              41 |
|                 Time |            18.9 |
------------------------------------------
ln118 Using action angle [nan]
Traceback (most recent call last):
  File "test/test_ppo_bandu.py", line 157, in <module>
    test_bandu(args.rank, 1)
  File "test/test_ppo_bandu.py", line 115, in test_bandu
    world_size=world_size)
  File "/home/richard/improbable/spinningup/spinup/algos/pytorch/ppo/ppo.py", line 474, in ppo
    next_o, r, d, _ = env.step(a)
  File "/home/richard/improbable/venvs/spinningup_venv/lib/python3.6/site-packages/gym/wrappers/time_limit.py", line 16, in step
    observation, reward, done, info = self.env.step(action)
  File "/home/richard/improbable/spinningup/envs/bandu.py", line 143, in step
    self.meshes = R.from_quat(action_quat).apply(self.meshes.reshape(nO*num_points, 3)).reshape(nO, num_points, 3)
  File "/home/richard/improbable/venvs/spinningup_venv/lib/python3.6/site-packages/scipy/spatial/transform/rotation.py", line 479, in from_quat
    return cls(quat, normalize=True)
  File "/home/richard/improbable/venvs/spinningup_venv/lib/python3.6/site-packages/scipy/spatial/transform/rotation.py", line 385, in __init__
    norms = scipy.linalg.norm(quat, axis=1)
  File "/home/richard/improbable/venvs/spinningup_venv/lib/python3.6/site-packages/scipy/linalg/misc.py", line 142, in norm
    a = np.asarray_chkfinite(a)
  File "/home/richard/improbable/venvs/spinningup_venv/lib/python3.6/site-packages/numpy/lib/function_base.py", line 499, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs

Basically, during the first epoch, when we are just sampling from the untrained tanh-normal policy (the "Using action angle [3.0207381]" lines above), everything works as expected: the sampled rotation angles are spread fairly uniformly over the range I am trying to model.

But as soon as I take gradient updates on the first batch, all of the rotation angles get pushed to the boundaries -3.14 and 3.14, and eventually the samples become NaN.
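
To make the failure mode concrete, here is a small standalone sketch (not my training code, just an illustration) showing that once the base Gaussian drifts far into the saturation region, tanh rounds to exactly 1.0 in float32, and inverting the transform, which log_prob has to do on anything that is not in the cache, is no longer finite:

    import math
    import torch
    import torch.distributions as D

    # Base Gaussian whose mean has drifted far into the tanh saturation region
    base = D.Normal(loc=torch.tensor([12.0]), scale=torch.tensor([0.1]))
    transforms = [D.TanhTransform(cache_size=1),
                  D.AffineTransform(loc=0, scale=math.pi, cache_size=1)]
    dist = D.TransformedDistribution(base, transforms)

    sample = dist.sample()
    print(sample)                 # tensor([3.1416]) -- pinned at the boundary
    print(dist.log_prob(sample))  # still finite here, thanks to cache_size=1

    # Without a cache hit, log_prob has to invert tanh numerically:
    t = D.TanhTransform()
    y = t(torch.tensor(12.0))     # tanh(12.0) rounds to exactly 1.0 in float32
    print(t.inv(y))               # atanh(1.0) -> inf, which then propagates non-finite values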

Why is this happening?