The dtype of optimizer states in PyTorch AMP training

Hello. I know that when using PyTorch AMP for training, the model weights are kept in float32, while the gradients are 16-bit. I'm now wondering what dtype the optimizer states use. For example, if I use the Adam optimizer for training under PyTorch AMP, will the dtype of Adam's states be 32-bit or 16-bit? Thanks.

The optimizer states should use the same dtype as the parameters (so float32 under AMP, since the parameters stay in float32), but you can quickly check it by printing the optimizer.state_dict().
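
For example, a minimal sketch along these lines (assuming a CUDA device is available; the model, shapes, and learning rate are just placeholders) runs one AMP step and then prints the dtypes of the tensors stored in the Adam state, which should come out as torch.float32:

```python
import torch
import torch.nn as nn

# Placeholder model and optimizer for the check
model = nn.Linear(8, 4).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(16, 8, device="cuda")

# One training step under autocast
with torch.cuda.amp.autocast():
    loss = model(x).sum()

scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# Inspect the dtype of each tensor in the optimizer state
# (for Adam these are e.g. exp_avg and exp_avg_sq)
for param, state in optimizer.state.items():
    for name, value in state.items():
        if torch.is_tensor(value):
            print(name, value.dtype)  # expected: torch.float32
```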
