We have a model in which one of the layers is used to generate logits for a OneHotCategorical distribution. With autocast disabled the code seems to work, but with autocast enabled it runs for a seemingly random number of iterations and then fails with the following error:
Traceback (most recent call last):
  File "C:\Users\emh\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "C:\Users\emh\AppData\Local\Programs\Python\Python39\lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\emh\Desktop\thesis-repo\venv\Scripts\rl_thesis.exe\__main__.py", line 7, in <module>
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\bin\rl_thesis.py", line 30, in entry_func
    script_module.entry_func(remaining_args + help_args)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\bin\train.py", line 78, in entry_func
    model.train(eval_env)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\models\Dreamer\model.py", line 264, in train
    train_driver(train_policy, steps=self.conf.eval_freq)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\driver.py", line 60, in __call__
    [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps]
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\driver.py", line 60, in <listcomp>
    [fn(tran, worker=i, **self._kwargs) for fn in self._on_steps]
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\models\Dreamer\model.py", line 200, in train_step
    self.state, mets = agent.learn(train_batch, self.state)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\Agent.py", line 69, in learn
    state, outputs, mets = self.wm.learn(data, state)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\WorldModel.py", line 33, in learn
    model_loss, state, outputs, metrics = self.loss(data, state)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\WorldModel.py", line 91, in loss
    metrics["prior_entropy"] = self.rssm.distribution_from_stats(
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\RSSM.py", line 287, in distribution_from_stats
    raise e
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\rl_thesis\dreamer\RSSM.py", line 281, in distribution_from_stats
    return torch.distributions.Independent(OneHotDist(logits=state.logit), 1)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\torch\distributions\one_hot_categorical.py", line 44, in __init__
    super(OneHotCategorical, self).__init__(batch_shape, event_shape, validate_args=validate_args)
  File "C:\Users\emh\Desktop\thesis-repo\venv\lib\site-packages\torch\distributions\distribution.py", line 55, in __init__
    raise ValueError(
ValueError: Expected parameter probs (Tensor of shape (16, 16, 16, 16)) of distribution OneHotDist() to satisfy the constraint Simplex(), but found invalid values:
Part of the code that fails:
try:
    return torch.distributions.Independent(OneHotDist(logits=state.logit), 1)
except ValueError as e:
    print("Failed")
    print(f"is inf: {torch.isinf(state.logit).any()}")  # prints false
    print(f"is nan: {torch.isnan(state.logit).any()}")  # prints false
    print(f"type: {state.logit.type()}")  # prints torch.cuda.HalfTensor
    raise e
The line within the try block is the one that fails, so I added some lines to get more info about the logits.
The class OneHotDist is one we implemented ourselves; it inherits from torch.distributions.OneHotCategorical but does not implement its own __init__, so we do not alter the logits before they are passed to the __init__ of the superclass.
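For reference, here is a self-contained sketch of the float32-cast workaround we are considering (OneHotDist here is a stand-in for our subclass, and the logits are random tensors with the shape from the error message, not our real data). Casting the half-precision logits to float32 before the constructor makes the softmax inside it, and the Simplex() validation of the resulting probs, run in full precision:

```python
import torch

class OneHotDist(torch.distributions.OneHotCategorical):
    # Stand-in for our subclass: it defines no __init__ of its own,
    # so the logits reach OneHotCategorical.__init__ unchanged.
    pass

# Half-precision logits, shaped like the tensor in the error message.
logits = torch.randn(16, 16, 16, 16).half()

# Cast to float32 before building the distribution; the normalization
# and the Simplex() constraint check then happen in full precision.
dist = torch.distributions.Independent(OneHotDist(logits=logits.float()), 1)
entropy = dist.entropy()
```

This only addresses the distribution construction; the rest of the forward pass would still run under autocast.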
Do we need to handle the logits in a special way when using autocast? For example, manually casting the logits to float32, scaling them with a GradScaler, or using the custom_fwd/custom_bwd decorators somewhere?
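For completeness, the custom_fwd/custom_bwd route we are asking about would presumably look something like the sketch below (CastToFloat32 is a hypothetical helper, not something in our code base). As we understand the docs, custom_fwd(cast_inputs=torch.float32) casts incoming floating-point CUDA tensors to float32 when forward runs inside an autocast region, and executes the body with autocast disabled:

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class CastToFloat32(torch.autograd.Function):
    """Hypothetical pass-through that forces float32 under CUDA autocast."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits):
        # Under CUDA autocast, `logits` arrives already cast to float32,
        # and autocast is disabled for the body of forward.
        return logits

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Identity backward: gradients pass through unchanged.
        return grad_output

# Usage sketch inside distribution_from_stats:
# logits32 = CastToFloat32.apply(state.logit)
# return torch.distributions.Independent(OneHotDist(logits=logits32), 1)
```

Outside an autocast region (or on CPU tensors) the cast is a no-op, so this would only change behavior in the autocast-enabled training path.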