Results differ when saving in TorchScript format in train vs. eval mode

Hi community. Here is the issue report I created. Can someone provide some insight or feedback on it? Thanks!

Describe the bug

It seems the dropout layer is not disabled by eval() when a Module was traced and saved in train mode through torch.jit.

Reproduce

Creating and saving the model via:

  • the state dict
  • a JIT trace (with the model in train mode)
  • a JIT trace (with the model in eval mode)
import torch
import torch.nn as nn
import numpy as np

torch.manual_seed(1)
np.random.seed(1)
x = torch.tensor(np.random.random(100).astype(np.float32).reshape(1, 1, 10, 10))
net = nn.Sequential(nn.Flatten(), nn.Linear(100, 4), nn.Dropout(0.5))
# set to eval mode
net.eval()
print(net(x).tolist())

# saving states
torch.save(net.state_dict(), "/tmp/saved_states.pth")

# tracing and saving in train mode
net.train()
traced_script_module = torch.jit.trace(net, x)
traced_script_module.save("/tmp/traced_train.pt")

# tracing and saving in eval mode
net.eval()
traced_script_module = torch.jit.trace(net, x)
traced_script_module.save("/tmp/traced_eval.pt")

Loading the model back with each of the three approaches and evaluating it:

reloaded_net_states = nn.Sequential(nn.Flatten(), nn.Linear(100, 4), nn.Dropout(0.5))
states = torch.load("/tmp/saved_states.pth")
reloaded_net_states.load_state_dict(states)

loaded_nets = {
    "states": reloaded_net_states,
    "traced_train": torch.jit.load("/tmp/traced_train.pt"),
    "traced_eval": torch.jit.load("/tmp/traced_eval.pt"),
}

for name, loaded_net in loaded_nets.items():
    print("testing", name)
    for i in range(3):
        # set to eval mode
        loaded_net.eval()
        print(loaded_net(x).tolist())

Output

[[-0.10291355848312378, 0.3682890832424164, -0.04384330287575722, -0.16762569546699524]]
testing states
[[-0.10291355848312378, 0.3682890832424164, -0.04384330287575722, -0.16762569546699524]]
[[-0.10291355848312378, 0.3682890832424164, -0.04384330287575722, -0.16762569546699524]]
[[-0.10291355848312378, 0.3682890832424164, -0.04384330287575722, -0.16762569546699524]]
testing traced_train
[[-0.20582711696624756, 0.0, -0.0, -0.0]]
[[-0.20582711696624756, 0.7365781664848328, -0.0, -0.3352513909339905]]
[[-0.0, 0.7365781664848328, -0.0, -0.0]]
testing traced_eval
[[-0.10291355848312378, 0.3682890832424164, -0.04384330287575722, -0.16762569546699524]]
[[-0.10291355848312378, 0.3682890832424164, -0.04384330287575722, -0.16762569546699524]]
[[-0.10291355848312378, 0.3682890832424164, -0.04384330287575722, -0.16762569546699524]]

Expected Output

Both saved TorchScript modules should produce the same results once eval() is called.

Versions

Collecting environment information…
PyTorch version: 1.10.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.7 | packaged by conda-forge | (default, Sep 29 2021, 19:20:46) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.11.0-38-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: NVIDIA TITAN RTX
GPU 1: NVIDIA TITAN RTX

Nvidia driver version: 495.29.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.3.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.3.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.4
[pip3] torch==1.10.0
[pip3] torchaudio==0.10.0+cu113
[pip3] torchvision==0.11.1+cu113
[conda] numpy 1.21.4 py39hdbf815f_0 conda-forge
[conda] torch 1.10.0+cu113 pypi_0 pypi
[conda] torchaudio 0.10.0+cu113 pypi_0 pypi
[conda] torchvision 0.11.1+cu113 pypi_0 pypi

PS: I also reproduced this on version 1.9.0.

Answered here.
TL;DR: tracing a module records the operations that actually execute (with or without dropout), so you cannot change that behavior afterwards.
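To illustrate, here is a minimal sketch based on the model from the question. It assumes a random input instead of the NumPy one, and it passes check_trace=False only to skip the tracer's sanity re-run, which would otherwise warn about the nondeterministic dropout output. The names traced_train and traced_eval are just for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(1)
x = torch.rand(1, 1, 10, 10)
net = nn.Sequential(nn.Flatten(), nn.Linear(100, 4), nn.Dropout(0.5))

# Reference output with dropout disabled.
net.eval()
expected = net(x)

# Trace in train mode: the recorded graph contains an active dropout.
net.train()
traced_train = torch.jit.trace(net, x, check_trace=False)

# Trace in eval mode: dropout is recorded as a pass-through.
net.eval()
traced_eval = torch.jit.trace(net, x)

# Calling eval() on a traced module does not change what was recorded.
traced_train.eval()
traced_eval.eval()
train_matches = torch.allclose(traced_train(x), expected)
eval_matches = torch.allclose(traced_eval(x), expected)
print("traced in train mode matches reference:", train_matches)
print("traced in eval mode matches reference:", eval_matches)
```

So if you want to keep using tracing, calling net.eval() before torch.jit.trace is the thing to check.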

Thank you so much for your help. I will use scripting instead of tracing.
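For anyone landing here later, a minimal sketch of the scripting route, under the assumption that the model is scriptable as-is (this small nn.Sequential is). It scripts the model while still in train mode, round-trips it through an in-memory buffer rather than a file in /tmp, and then checks that eval() takes effect on the reloaded module.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(1)
x = torch.rand(1, 1, 10, 10)
net = nn.Sequential(nn.Flatten(), nn.Linear(100, 4), nn.Dropout(0.5))

# Reference output with dropout disabled.
net.eval()
expected = net(x)

# Script (rather than trace) the model while it is in train mode.
net.train()
scripted = torch.jit.script(net)

# Save to and load from an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

# The scripted module keeps the training flag, so eval() disables dropout.
reloaded.eval()
outputs = [reloaded(x) for _ in range(3)]
matches = all(torch.allclose(out, expected) for out in outputs)
print("scripted module matches reference after eval():", matches)
```

Unlike tracing, scripting compiles the module's forward code, so the training flag stays a runtime attribute rather than a constant baked in at export time.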