I have a module whose forward method is annotated with torch.jit.script_method and uses self.training to control dropout:
class Convs(torch.jit.ScriptModule):
    def __init__(self):
        self.conv = nn.Conv(etc)

    @torch.jit.script_method
    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(self.conv(x)), 0.5, self.training)
When I load a state_dict into the jitted model, I get a missing-key error:
Traceback (most recent call last):
  File "jit_test.py", line 54, in <module>
    t.load_state_dict(sd_new)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 769, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for Tacotron2:
    Missing key(s) in state_dict: "encoder.training".
How should self.training be used with torch.jit so that the behavior matches non-JIT PyTorch?
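For reference, here is a minimal sketch of one thing that might be worth trying. It assumes a newer PyTorch release in which torch.jit.script can compile a plain nn.Module, and in which load_state_dict returns the lists of missing and unexpected keys; neither is guaranteed for the version in the traceback above. The Conv1d layer, its sizes, and the input shape are hypothetical stand-ins for the real model. Passing strict=False lets the load proceed when a key such as "encoder.training" is absent from the checkpoint, and the returned lists show exactly what did not match.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Convs(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layer and sizes, chosen only for illustration.
        self.conv = nn.Conv1d(4, 4, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # self.training is read at call time, so train()/eval() toggles dropout.
        return F.dropout(F.relu(self.conv(x)), 0.5, self.training)


eager = Convs()                        # ordinary (non-JIT) module
scripted = torch.jit.script(Convs())   # compiled counterpart

# strict=False tolerates keys present on only one side and reports them
# instead of raising a RuntimeError.
result = scripted.load_state_dict(eager.state_dict(), strict=False)
missing, unexpected = result.missing_keys, result.unexpected_keys

# In eval mode dropout is disabled, so two passes over the same input agree.
scripted.eval()
x = torch.randn(2, 4, 8)
same_in_eval = torch.equal(scripted(x), scripted(x))
```

If missing still lists a "training" entry on an older release, adding that key to the checkpoint before loading, or keeping strict=False, are possible workarounds; I have not verified either against the version shown in the traceback.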