torch.cuda.amp.autocast breaks the Simplex constraint

We have a model in which one of the layers generates logits for a OneHotCategorical distribution. With autocast disabled the code seems to work, but with autocast enabled it runs for a seemingly random number of iterations and then fails with the following error:

Traceback (most recent call last):
  File "C:\Users\emh\AppData\Local\Programs\Python\Python39\lib\", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\emh\AppData\Local\Programs\Python\Python39\lib\", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\emh\Desktop\thesis-repo\venv\Scripts\rl_thesis.exe\", line 7, in <module>
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\bin\", line 30, in entry_func
    script_module.entry_func(remaining_args + help_args)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\bin\", line 78, in entry_func
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\models\Dreamer\", line 264, in train
    train_driver(train_policy, steps=self.conf.eval_freq)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\", line 60, in __call__
    [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps]
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\", line 60, in <listcomp>
    [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps]
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\models\Dreamer\", line 200, in train_step
    self.state, mets = agent.learn(train_batch, self.state)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\", line 69, in learn
    state, outputs, mets = self.wm.learn(data, state)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\", line 33, in learn
    model_loss, state, outputs, metrics = self.loss(data, state)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\", line 91, in loss
    metrics["prior_entropy"] = self.rssm.distribution_from_stats(
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\", line 287, in distribution_from_stats
    raise e
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\", line 281, in distribution_from_stats
    return torch.distributions.Independent(OneHotDist(logits=state.logit), 1)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\torch\distributions\", line 44, in __init__
    super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\torch\distributions\", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (16, 16, 16, 16)) of distribution OneHotDist() to satisfy the constraint Simplex(), but found invalid values:

Part of the code that fails:

    try:
        return torch.distributions.Independent(OneHotDist(logits=state.logit), 1)
    except ValueError as e:
        print(f"is inf: {torch.isinf(state.logit).any()}")  # False
        print(f"is nan: {torch.isnan(state.logit).any()}")  # False
        print(f"type: {state.logit.type()}")  # torch.cuda.HalfTensor
        raise e

The return statement inside the try block is the one that fails, so I added the prints in the except block to get more information about the logits.
OneHotDist is a class we wrote ourselves. It inherits from torch.distributions.OneHotCategorical but does not define its own __init__, so the logits are passed unchanged to the __init__ of the parent class.

Do we need to handle the logits in a special way when using autocast? For example, should we manually cast the logits to float32, scale them with a GradScaler, or use the custom_fwd/custom_bwd decorators somewhere?

I don’t know whether casting the logits to float32 would help (you could try it) or whether that already happens internally; either way, the values are already reported as “invalid”.
Do you know what “invalid” means in this case and which values are unexpected in state.logit?
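One plausible reading of “invalid” here, assuming a recent PyTorch release: the Simplex() constraint in torch.distributions.constraints accepts a probability vector only if every entry is non-negative and the entries sum to 1 within an absolute tolerance of 1e-6. Because the logits are float16 (HalfTensor), the probabilities derived from them are float16 too, and the gap between 1.0 and the next smaller float16 value is 2**-11 (about 4.9e-4). Any row whose sum does not round to exactly 1.0 therefore fails, even with no inf or NaN present, which could explain why the error only shows up after a seemingly random number of iterations. The sketch below reproduces the check on hand-picked values; `report_simplex_violations` is a hypothetical diagnostic helper, not part of PyTorch.

```python
import torch
from torch.distributions import constraints

# Machine epsilon: float16 cannot resolve a 1e-6 tolerance around 1.0, float32 can.
print(torch.finfo(torch.float16).eps)  # 0.0009765625
print(torch.finfo(torch.float32).eps)  # 1.1920928955078125e-07

# Two float16 probabilities summing to 1 - 2**-11, the closest float16 value below 1.0.
# This is the smallest possible shortfall in float16, yet it already fails the check.
probs_half = torch.tensor([0.5, 0.49951171875], dtype=torch.float16)
print(constraints.simplex.check(probs_half))  # tensor(False)


def report_simplex_violations(logit: torch.Tensor) -> None:
    """Hypothetical helper: report how far softmax(logit) rows are from summing to 1."""
    probs = torch.softmax(logit, dim=-1)  # keeps the dtype of logit
    deviation = (probs.sum(-1) - 1).abs()
    failing = ~constraints.simplex.check(probs)  # True where a row violates Simplex()
    print(
        f"dtype={probs.dtype}, max |sum - 1| = {deviation.max().item():.3e}, "
        f"rows failing Simplex(): {int(failing.sum())} of {failing.numel()}"
    )
```

Calling `report_simplex_violations(state.logit)` in the except block would show whether the failing rows are off only by float16 rounding rather than containing genuinely bad values.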

Simply casting the logits to float32 worked. Thank you for the response.
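A minimal sketch of that fix, assuming OneHotDist is a plain subclass of OneHotCategorical as described above; `distribution_from_logits` is a hypothetical stand-in for the poster's `distribution_from_stats`. The logits are upcast with `.float()` before the distribution is built, so the softmax and the Simplex validation run in float32. The `torch.autocast(..., enabled=False)` block is optional: it keeps further ops inside it out of autocast, but it does not convert tensors that are already float16, which is why the explicit cast is still needed. A GradScaler would not help here on its own, since it only scales the loss for the backward pass and leaves the forward logits unchanged.

```python
import torch
from torch.distributions import Independent, OneHotCategorical


class OneHotDist(OneHotCategorical):
    """Stand-in for the poster's class: no custom __init__, logits go straight to the parent."""


def distribution_from_logits(logit: torch.Tensor) -> Independent:
    """Hypothetical helper: build the per-variable one-hot distribution in float32."""
    # Optional: keep ops in this block out of autocast. This does NOT recast
    # tensors that autocast already produced in float16.
    with torch.autocast(device_type=logit.device.type, enabled=False):
        # The actual fix: upcast float16 logits so softmax and the Simplex()
        # validation run in float32.
        return Independent(OneHotDist(logits=logit.float()), 1)


# Usage: half-precision logits like the ones autocast produced (shape from the traceback).
logit = torch.randn(16, 16, 16, 16).half()
dist = distribution_from_logits(logit)
print(dist.base_dist.logits.dtype)  # torch.float32
print(dist.entropy().shape)  # torch.Size([16, 16])
```

Casting only at the distribution boundary keeps the rest of the network in mixed precision while the probability math, which is sensitive to rounding, stays in float32.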