Encoder generates NaN during VAE learning process

I looked into this and it seems to be a common problem, but I have not been able to solve it.

I am trying to render one of the GymEnv environments as pixels and train a VAE on those observations. However, during training the VAE's output becomes NaN and learning does not seem to work.

Here is minimal code that reproduces the NaN problem (since the code is too long, I have uploaded it to pastebin):

I have examined the intermediate outputs inside the VAE and found that the NaN first appears in the encoder. However, I am stumped as to how to resolve this.
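For reference, the checks I added look roughly like this (a sketch only; the real code in the pastebin differs and the names here are placeholders):

import torch

def report_nan(name, tensor):
    # Prints True as soon as any element of the tensor is NaN.
    print(f"{name} isnan:{torch.isnan(tensor).any().item()}")

# Called on the input and after each encoder layer, e.g.:
#   report_nan("input", x)
#   x = self.encoder_layer(x)
#   report_nan("encoded_x", x)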

Running the pastebin code produces a tensor that becomes NaN in the VAE encoder after several rounds of training. Below is an example of the output when the problem was reproduced in my environment:

input isnan1:False
encoded_x isnan2:False
encoded_x isnan3:False
encoded_x isnan4:False
torch.Size([1000, 32])
input isnan1:False
encoded_x isnan2:False
encoded_x isnan3:False
encoded_x isnan4:False
torch.Size([1000, 32])
input isnan1:False
encoded_x isnan2:False
encoded_x isnan3:False
encoded_x isnan4:False
torch.Size([1000, 32])
input isnan1:False
encoded_x isnan2:False
encoded_x isnan3:False
encoded_x isnan4:False
torch.Size([1000, 32])
input isnan1:False
encoded_x isnan2:True
encoded_x isnan3:True
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<CloneBackward0>)
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<ReluBackward0>)

Edit

I forgot to list the installed packages. Here is the result of pip freeze:

absl-py==2.0.0
certifi==2023.7.22
cffi==1.16.0
charset-normalizer==3.3.0
cloudpickle==3.0.0
colorama==0.4.6
contourpy==1.1.1
cycler==0.12.1
filelock==3.12.4
fonttools==4.43.1
fsspec==2023.9.2
glcontext==2.5.0
glfw==1.12.0
gym==0.26.2
gym-notices==0.0.8
gym3==0.3.3
idna==3.4
imageio==2.31.5
imageio-ffmpeg==0.3.0
Jinja2==3.1.2
kiwisolver==1.4.5
MarkupSafe==2.1.3
matplotlib==3.8.0
moderngl==5.8.2
mpmath==1.3.0
mujoco==2.2.0
networkx==3.1
numpy==1.26.1
packaging==23.2
Pillow==10.1.0
procgen==0.10.7
pycparser==2.21
PyOpenGL==3.1.7
pyparsing==3.1.1
python-dateutil==2.8.2
requests==2.31.0
six==1.16.0
sympy==1.12
tensordict==0.2.0
torch==2.1.0
torchrl==0.2.0
torchvision==0.16.0
tqdm==4.66.1
typing_extensions==4.8.0
urllib3==2.0.6
  1. Have you tried substituting in a more standard loss function to see if that’s the issue? (A sketch of a standard VAE loss is at the end of this reply.)
  2. You might try clipping the objective function. See here: Proximal Policy Optimization — Spinning Up documentation
  3. You might also try gradient clipping. See here: torch.nn.utils.clip_grad_norm_ — PyTorch 2.1 documentation

For example:

loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
optimizer.step()

Lastly, you might try lowering the learning rate of your optimizer.
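Regarding point 1, a standard VAE objective combines a reconstruction term with a KL-divergence term. A minimal sketch, with illustrative names rather than anything from your code:

import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar):
    # Reconstruction term: how closely the decoder reproduces the input.
    recon = F.mse_loss(recon_x, x, reduction="sum")
    # KL divergence between q(z|x) = N(mu, sigma^2) and the N(0, I) prior.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld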

Thank you for the reply.
Here are the points I have been able to confirm:

  1. I changed the loss function to a BCE version and a Gaussian-loss version, but the VAE’s encoder still outputs NaN during training.

  2. I tried gradient clipping, but the VAE outputs NaN the same as before.

  3. I tried lowering the learning rate (from 10^-4 down to 10^-6), but the result is still NaN. A lower learning rate seems to delay the NaN by more training steps, but it eventually appears.

Thanks for trying. That eliminates a few possibilities.

Was looking over the code again, and where are you scaling the image values from 0-255 to 0-1? I see you imported ToTensorImage but didn’t use it.
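For example, a minimal sketch of adding it as an environment transform (the env name and keys here are assumptions, not taken from your code):

from torchrl.envs import TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.transforms import ToTensorImage

# Hypothetical pixel env; your GymEnv setup will differ.
base_env = GymEnv("CartPole-v1", from_pixels=True, pixels_only=True)

# ToTensorImage casts the uint8 "pixels" entry to float32 in [0, 1]
# and permutes it from HWC to CHW before it reaches the VAE.
env = TransformedEnv(base_env, ToTensorImage(in_keys=["pixels"]))

td = env.reset()
print(td["pixels"].dtype, td["pixels"].shape)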

Thanks for your help! My problem seems to have been solved by adding ToTensorImage!

Just out of curiosity, why does ToTensorImage prevent the NaN from appearing? I think it’s related to the gradient computation, but I’m not really sure what’s actually happening.

Images by default use uint8, and the values fall between 0 and 255.

But for machine learning, you want the values scaled from 0 to 1. That class will rescale the images, and also permute the channels dim.
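To make that concrete, this is roughly the conversion it applies to each frame (a hand-rolled sketch, not the library code):

import torch

# Raw frames come out of the env as uint8, HWC, values in [0, 255].
frame = torch.randint(0, 256, (84, 84, 3), dtype=torch.uint8)

# Roughly what ToTensorImage does: cast to float, scale to [0, 1],
# and move the channel dim first (HWC -> CHW).
scaled = frame.permute(2, 0, 1).float() / 255.0
print(scaled.dtype, scaled.shape, scaled.max().item())

Without that rescaling, the encoder sees raw values up to 255, which tends to produce very large activations and gradients, and that is likely what eventually blew up into NaN.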