Based on your debugging so far, I would guess that the NaNs are created inside the model at some point.
A “brute force” approach would be to register forward hooks to all layers and check their output for invalid values to further narrow down the first occurrence of the NaNs.
To do so, you could use this code and call torch.isfinite(tensor).all() inside the hook.
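As a rough illustration of the hook approach, here is a minimal sketch. The helper names (make_finite_check_hook, register_finite_checks) and the toy model are my own assumptions for the example, not part of your code; the NaN is injected on purpose so that the hooks have something to find.

```python
import torch
import torch.nn as nn


def make_finite_check_hook(name, bad_layers):
    # Forward hooks are called as hook(module, inputs, output).
    def hook(module, inputs, output):
        outputs = output if isinstance(output, (tuple, list)) else (output,)
        for out in outputs:
            if torch.is_tensor(out) and not torch.isfinite(out).all():
                bad_layers.append(name)  # record in execution order
                break
    return hook


def register_finite_checks(model):
    # Hypothetical helper: attaches the check to every submodule.
    bad_layers, handles = [], []
    for name, module in model.named_modules():
        if name == "":
            continue  # skip the root module itself
        hook = make_finite_check_hook(name, bad_layers)
        handles.append(module.register_forward_hook(hook))
    return bad_layers, handles


# Toy model (an assumption for this sketch) with a NaN injected
# into the first layer's weight to simulate the problem.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    model[0].weight[0, 0] = float("nan")

bad_layers, handles = register_finite_checks(model)
model(torch.ones(1, 4))
print("layers with invalid outputs, in execution order:", bad_layers)

# Remove the hooks again once you are done debugging.
for handle in handles:
    handle.remove()
```

The first entry in bad_layers is the layer you would want to look at more closely, since everything after it just propagates the invalid values.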
Once you see the first NaN output, you could then also check the inputs to that layer as well as its parameters to determine whether the activations are overflowing or why else the tensor contains invalid values.
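For this second step, something along these lines might help. It is only a sketch: describe and inspect_layer are hypothetical helpers, and the layer and values are made up so that finite inputs and finite parameters still overflow float32 in the output.

```python
import torch
import torch.nn as nn


def describe(name, tensor):
    # Summarize one tensor: is it fully finite, and how large does it get?
    t = tensor.detach()
    return {
        "name": name,
        "finite": bool(torch.isfinite(t).all()),
        "max_abs": t.abs().max().item(),
    }


def inspect_layer(module, inputs):
    # Hypothetical helper: report on the layer's inputs and its parameters.
    report = [
        describe(f"input[{i}]", x)
        for i, x in enumerate(inputs)
        if torch.is_tensor(x)
    ]
    report += [describe(f"param:{n}", p) for n, p in module.named_parameters()]
    return report


# Made-up suspect layer: all values are finite but very large.
layer = nn.Linear(2, 1)
with torch.no_grad():
    layer.weight.fill_(1e30)
    layer.bias.zero_()

x = torch.full((1, 2), 1e30)
out = layer(x)

report = inspect_layer(layer, (x,))
for entry in report:
    print(entry)
print("output finite:", bool(torch.isfinite(out).all()))
```

If the inputs and parameters are all finite but the output is not, a very large max_abs would point towards an overflow inside the layer. If a parameter is already invalid, the problem more likely comes from an earlier update step.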