Yeah, that’s probably the issue and might be considered a bug.
As you can see here, `set_training` changes the current training mode of the model to the passed `mode` argument. By default `torch.onnx.export` uses `training=False`, which should be fine.
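The relevant logic boils down to something like this (a simplified sketch of the linked helper, not the verbatim torch.onnx source):

```python
from contextlib import contextmanager

@contextmanager
def set_training(model, mode):
    # Only the parent module's training flag is inspected here
    old_mode = model.training
    if old_mode != mode:
        model.train(mode)
    try:
        yield
    finally:
        if old_mode != mode:
            # model.train() is recursive, so this also resets every submodule
            model.train(old_mode)
```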
However, since you are not setting the complete model to `eval`, `net.training` will still return `True`.
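For reference, `net` could be a minimal module with a single batchnorm child along these lines (a hypothetical definition, just to make the snippets reproducible):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

net = Net()
x = torch.randn(1, 3, 4, 4)
```

With such a model: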
```python
net.bn.eval()  # only the batchnorm submodule is switched to eval
print(net.training)
> True
print(net.bn.training)
> False
```
While this is your desired use case, `set_training` only checks the `training` attribute of the parent model and sets the complete model to its “old” mode again:
```python
torch.onnx.export(net, x, 'tmp.onnx')  # export resets the whole model to train mode
print(net.training)
> True
print(net.bn.training)
> True
```
This will of course cause the next forward call to update the running statistics, so you should set the batchnorm layer to `eval` again after exporting the model using ONNX.
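In code, that workaround is a one-liner right after the export:

```python
torch.onnx.export(net, x, 'tmp.onnx')
net.bn.eval()  # undo the export's side effect before the next forward pass
```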
The proper approach would maybe be to restore the `training` attribute for each submodule recursively, but I’m not sure if that’s an edge case.
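Such a recursive restore could look like this (a hypothetical helper, not an existing API):

```python
from contextlib import contextmanager

@contextmanager
def preserve_training_modes(model):
    # Snapshot the training flag of every submodule, not just the parent
    modes = {name: m.training for name, m in model.named_modules()}
    try:
        yield
    finally:
        # Set the attribute directly instead of calling train(), which recurses
        for name, m in model.named_modules():
            m.training = modes[name]

with preserve_training_modes(net):
    net.train(False)  # whatever mode the exporter needs
    torch.onnx.export(net, x, 'tmp.onnx')
print(net.bn.training)
> False
```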
Anyway, feel free to open an issue and link to this topic so that this can be discussed with the ONNX devs.