Model state_dict full of NaNs but the code is running?

I have been saving checkpoints of my model every n iterations during training. After I load a checkpoint, the networks quickly fill up with NaNs and the run soon crashes. After investigating, I found that the state_dicts I saved are already full of NaNs. However, I believe torch.save is working as intended and is not corrupting my data. That would suggest my networks are training while full of NaNs, but this doesn’t make sense, because the networks continue to run without issue. Only after I load the serialized state_dicts does the code crash.

It seems that either the serialization code is corrupting my networks’ parameters, or the networks run without issue while full of NaN values and only crash once the parameters are deserialized and loaded back in.
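One way to tell these two hypotheses apart is to scan the live networks for non-finite values right before each torch.save call: if the scan is clean but the loaded checkpoint is not, serialization becomes the main suspect. A minimal sketch, assuming a hypothetical helper named nonfinite_entries and a small nn.Sequential standing in for the real policy and critic networks:

```python
import torch
import torch.nn as nn


def nonfinite_entries(module: nn.Module) -> list:
    """Hypothetical helper: names of state_dict entries containing NaN or Inf."""
    return [
        name
        for name, tensor in module.state_dict().items()
        # torch.isfinite is False for NaN, +Inf and -Inf
        if not torch.isfinite(tensor).all()
    ]


# Stand-in network; in practice this would be called on self.policy,
# self.critic, etc. right before torch.save.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
print(nonfinite_entries(net))  # → []

with torch.no_grad():
    net[0].weight[0, 0] = float("nan")  # simulate one corrupted weight
print(nonfinite_entries(net))  # → ['0.weight']
```

If this reports nothing at save time, the NaNs are most likely being introduced during saving or loading rather than during training.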

My loading/saving code is simple:

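```python
import os
import torch

class SAC(object):  # enclosing agent class; name assumed, __init__ omitted
    # Save model parameters
    def save_checkpoint(self, env_name, suffix="", ckpt_path=None):
        if not os.path.exists('checkpoints/'):
            os.makedirs('checkpoints/')
        if ckpt_path is None:
            ckpt_path = "checkpoints/sac_checkpoint_{}_{}".format(env_name, suffix)
        print('Saving models to {}'.format(ckpt_path))
        torch.save({'policy_state_dict': self.policy.state_dict(),
                    'critic_state_dict': self.critic.state_dict(),
                    'critic_target_state_dict': self.critic_target.state_dict(),
                    'critic_optimizer_state_dict': self.critic_optim.state_dict(),
                    'policy_optimizer_state_dict': self.policy_optim.state_dict()}, ckpt_path)

    # Load model parameters
    def load_checkpoint(self, ckpt_path, evaluate=False):
        print('Loading models from {}'.format(ckpt_path))
        if ckpt_path is not None:
            checkpoint = torch.load(ckpt_path)
            # Restore each entry written by save_checkpoint (reconstructed)
            self.policy.load_state_dict(checkpoint['policy_state_dict'])
            self.critic.load_state_dict(checkpoint['critic_state_dict'])
            self.critic_target.load_state_dict(checkpoint['critic_target_state_dict'])
            self.critic_optim.load_state_dict(checkpoint['critic_optimizer_state_dict'])
            self.policy_optim.load_state_dict(checkpoint['policy_optimizer_state_dict'])

            # Branch body reconstructed: presumably switches eval/train mode
            if evaluate:
                self.policy.eval()
            else:
                self.policy.train()
```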

My networks are also fairly simple MLPs:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class QNetwork(torch.jit.ScriptModule):
    def __init__(self, num_inputs, num_actions, hidden_dim):
        super(QNetwork, self).__init__()

        # Q1 architecture
        self.linear1 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)

        # Q2 architecture
        self.linear4 = nn.Linear(num_inputs + num_actions, hidden_dim)
        self.linear5 = nn.Linear(hidden_dim, hidden_dim)
        self.linear6 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        # Concatenate state and action along the feature dimension
        xu = torch.cat([state, action], 1)
        x1 = F.relu(self.linear1(xu))
        x1 = F.relu(self.linear2(x1))
        x1 = self.linear3(x1)

        x2 = F.relu(self.linear4(xu))
        x2 = F.relu(self.linear5(x2))
        x2 = self.linear6(x2)

        return x1, x2
```

I’ve tried changing my learning rate and using normalization layers, but no luck. Any ideas? I can share more code if that would help.

The issue sounds quite strange, since you should see a NaN output if any parameter of the model is already invalid (Inf or NaN).
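To illustrate why a corrupted parameter cannot stay hidden, here is a small sketch using a single nn.Linear with arbitrary sizes as a stand-in for one of the QNetwork layers: one NaN weight turns the corresponding output column into NaN for every input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2)  # stand-in for one of the QNetwork layers
x = torch.randn(5, 3)

print(torch.isnan(layer(x)).any().item())  # → False

with torch.no_grad():
    layer.weight[0, 0] = float("nan")  # corrupt a single weight

out = layer(x)
# Output column 0 depends on weight[0, 0] for every input row, so it is all NaN;
# column 1 only uses weight[1, :] and bias[1], so it stays finite.
print(torch.isnan(out[:, 0]).all().item())  # → True
print(torch.isnan(out[:, 1]).any().item())  # → False
```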
Could you check the state_dict after saving it by directly re-loading it and comparing the values for each parameter? Also, which PyTorch version are you using? We’ve had an issue in PyTorch ~1.6 or so, which created invalid values while saving a state_dict containing parameters on the GPU.
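The suggested re-load check can be done without touching disk. This is a sketch under a few assumptions: it round-trips through an io.BytesIO buffer instead of the real checkpoint file, same_values and roundtrip_mismatches are hypothetical helper names, and NaN is treated as equal to NaN so that only values actually changed by serialization are reported.

```python
import io

import torch
import torch.nn as nn


def same_values(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Exact comparison that treats NaN as equal to NaN for float tensors."""
    if a.is_floating_point():
        return torch.allclose(a, b, rtol=0.0, atol=0.0, equal_nan=True)
    return torch.equal(a, b)


def roundtrip_mismatches(module: nn.Module) -> list:
    """Save module.state_dict() to memory, load it back, and return the names
    of entries whose reloaded values differ from the live ones."""
    live = module.state_dict()
    buffer = io.BytesIO()
    torch.save(live, buffer)
    buffer.seek(0)  # rewind so torch.load reads from the start
    reloaded = torch.load(buffer)
    return [name for name, t in live.items() if not same_values(t, reloaded[name])]


net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
print(roundtrip_mismatches(net))  # → []

# The issue mentioned above involved GPU parameters, so repeat on CUDA if present.
if torch.cuda.is_available():
    print(roundtrip_mismatches(net.cuda()))
```

A non-empty list on the GPU run but not on the CPU run would point at the serialization bug rather than at training.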
If you are using an older PyTorch release, try passing _use_new_zipfile_serialization=False to torch.save, or update to the current stable release.
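A minimal sketch of that workaround, assuming a small nn.Linear and an in-memory buffer as stand-ins for the real networks and checkpoint path. It prints the installed version and writes the legacy (pre-zipfile) format; it does not by itself verify that this avoids the bug on any particular release.

```python
import io

import torch
import torch.nn as nn

# Report the installed release.
print(torch.__version__)

net = nn.Linear(4, 2)  # stand-in for the real networks
if torch.cuda.is_available():
    net = net.cuda()  # the reported issue involved GPU parameters
original = {k: v.detach().clone() for k, v in net.state_dict().items()}

buffer = io.BytesIO()  # stand-in for the checkpoint path
# Workaround: write the legacy (pre-zipfile) format instead of the default one.
torch.save(net.state_dict(), buffer, _use_new_zipfile_serialization=False)
buffer.seek(0)
state = torch.load(buffer)
net.load_state_dict(state)
```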