torch.cuda.amp.autocast breaks simplex constraint

We have a model in which one of the layers generates the logits for a OneHotCategorical distribution. With autocast disabled the code seems to work, but with autocast enabled it runs for a seemingly random number of iterations and then fails with the following error:

"ValueError: Expected parameter probs (Tensor of shape (16, 16, 16, 16)) of distribution OneHotDist() to satisfy the constraint Simplex(), but found invalid values:"

Part of the code that fails:

    try:
        return torch.distributions.Independent(OneHotDist(logits=state.logit), 1)
    except ValueError as e:
        # Inspect the logits that triggered the failure.
        print(f"is inf: {torch.isinf(state.logit).any()}")  # prints False
        print(f"is nan: {torch.isnan(state.logit).any()}")  # prints False
        print(f"type: {state.logit.type()}")  # prints torch.cuda.HalfTensor
        raise e

The line inside the try block is the one that fails, so I added a few prints to get more information about the logits.
OneHotDist is a class we implemented ourselves. It inherits from torch.distributions.OneHotCategorical but does not define its own __init__, so the logits are not altered before they reach the __init__ of the superclass.

Do we need to handle the logits in a special way when using autocast? For example, manually casting the logits to float32, scaling them with a GradScaler, or using the custom_fwd/custom_bwd decorators somewhere?

I don’t know if casting the logits to float32 would help (you could try it) or if it’s already done, but the values are already “invalid”.
Do you know what “invalid” means in this case and which values are unexpected in state.logit?
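
To see which values trip the check, something like the sketch below might help. `simplex_report` is a hypothetical helper, not part of PyTorch. It assumes the validation is the standard `torch.distributions.constraints.simplex` check, which, as far as I know, requires non-negative entries and each row summing to 1 within roughly 1e-6. Probabilities computed in half precision could easily miss that tolerance. The example tensor is made up for illustration.

```python
import torch
from torch.distributions import constraints


def simplex_report(probs):
    """Hypothetical helper: per-row simplex validity and |sum - 1|."""
    valid = constraints.simplex.check(probs)   # bool tensor, one entry per row
    deviation = (probs.sum(-1) - 1).abs()      # distance of each row sum from 1
    return valid, deviation


# Made-up probabilities: the first row sums to exactly 1,
# the second row is off by about 1e-4.
example_probs = torch.tensor([
    [0.25, 0.25, 0.25, 0.25],
    [0.3333, 0.3333, 0.3333, 0.0],
])
valid, deviation = simplex_report(example_probs)
print(valid)
print(deviation)
```

In the except block this could be called on `torch.softmax(state.logit, dim=-1)`. That this matches what the distribution computes internally under autocast is an assumption.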

Simply casting the logits to float32 worked. Thank you for the response.
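
For anyone finding this later, the change was roughly the following. This is a minimal sketch rather than our actual code: it uses `OneHotCategorical` as a stand-in for our `OneHotDist` subclass, `make_one_hot_dist` is a hypothetical wrapper name, and the zero tensor stands in for `state.logit`.

```python
import torch
from torch.distributions import Independent, OneHotCategorical


def make_one_hot_dist(logits):
    # Cast to float32 before building the distribution so that the
    # probabilities are derived and validated in full precision.
    return Independent(OneHotCategorical(logits=logits.float()), 1)


# Stand-in for state.logit: half-precision logits with 16 classes.
half_logits = torch.zeros(4, 8, 16, dtype=torch.float16)
dist = make_one_hot_dist(half_logits)

print(dist.base_dist.probs.dtype)
print(dist.batch_shape, dist.event_shape)
```

The cast sits at the point where the distribution is constructed, so the rest of the model still runs under autocast.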