
Yeah, that’s probably the issue and might be considered a bug.

As you can see in the exporter's source, set_training temporarily changes the training mode of the model to the passed mode argument and restores the previous mode once the export is done.
By default, torch.onnx.export uses training=False, which should be fine.

However, since you are not setting the complete model to eval, net.training will still return True:

```python
net.bn.eval()
print(net.training)  # True
print(net.bn.training)  # False
```

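While this is your intended setup, set_training only checks the training attribute of the parent model. Since net.training (True) differs from the export mode (False), it calls net.train(False) before the export and net.train(True) afterwards. Because train() is applied recursively, this puts every submodule, including bn, back into its "old" training mode: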

```python
torch.onnx.export(net, x, 'tmp.onnx')
print(net.training)  # True
print(net.bn.training)  # True
```
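The effect can be reproduced without the exporter. Below is a minimal sketch of a set_training-style context manager, assuming it behaves as described in this post (only the parent's training flag is saved, and model.train() is used to switch and restore). The name set_training_sketch and the small Net model are hypothetical and only for illustration; this is not the exporter's actual code.

```python
import contextlib

import torch.nn as nn


@contextlib.contextmanager
def set_training_sketch(model, mode):
    # Hypothetical simplification of the exporter helper: it only looks at
    # the *parent* module's training flag.
    old_mode = model.training
    if old_mode != mode:
        model.train(mode)  # recursively sets every submodule to `mode`
    try:
        yield
    finally:
        if old_mode != mode:
            model.train(old_mode)  # recursively resets *every* submodule


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(self.fc(x))


net = Net()
net.bn.eval()  # only the batchnorm layer is in eval mode
with set_training_sketch(net, False):
    inside = net.bn.training  # the export would run here, all modules in eval
print(inside, net.training, net.bn.training)  # False True True
```

During the "export" every module is in eval mode, but on exit the parent's old flag (True) is pushed down to all children, so the per-layer eval setting of bn is lost.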

This will of course cause the next forward pass to update the running statistics again, so you should set the batchnorm layer back to eval mode after exporting the model to ONNX.
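To illustrate why this matters, here is a small sketch (assuming a standalone nn.BatchNorm1d with default settings, momentum=0.1 and track_running_stats=True) showing that a forward pass in training mode changes running_mean, while the same pass after switching the layer back to eval leaves it untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) + 3.0  # batch mean far from the initial running_mean of 0
before = bn.running_mean.clone()

# Mode the layer is left in after the export: training. A forward pass
# now blends this batch's mean into running_mean.
bn.train()
bn(x)
updated = not torch.equal(bn.running_mean, before)

# Workaround: switch the layer back to eval mode after exporting.
bn.eval()
snapshot = bn.running_mean.clone()
bn(x)
unchanged = torch.equal(bn.running_mean, snapshot)

print(updated, unchanged)  # True True
```

In the original setup this simply means calling net.bn.eval() again right after torch.onnx.export returns.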

A proper fix would probably be to save and restore the training attribute of each submodule individually, but I'm not sure whether this is considered an edge case.
Anyway, feel free to open an issue and link to this topic so that it can be discussed with the ONNX developers.
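Until then, the per-submodule restore suggested above could be approximated on the user side. The following is a hedged sketch, not part of torch.onnx: preserve_submodule_modes is a hypothetical helper that records the training flag of every module returned by model.modules() and restores each flag directly on exit. The two net.train() calls inside the with-block only stand in for the exporter's recursive switch and reset.

```python
import contextlib

import torch.nn as nn


@contextlib.contextmanager
def preserve_submodule_modes(model):
    # Hypothetical helper: snapshot the training flag of the model and of
    # every submodule it contains.
    saved = [(m, m.training) for m in model.modules()]
    try:
        yield
    finally:
        # Restore each flag individually. Assigning m.training directly
        # avoids m.train(), which would recurse and overwrite the children.
        for m, was_training in saved:
            m.training = was_training


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(self.fc(x))


net = Net()
net.bn.eval()
with preserve_submodule_modes(net):
    # Stand-in for the exporter's recursive mode switch and reset;
    # torch.onnx.export(net, x, 'tmp.onnx') would go here instead.
    net.train(False)
    net.train(True)
print(net.training, net.bn.training)  # True False
```

Setting the attribute directly, rather than calling train(), is the key design choice: train() always propagates to children, which is exactly what overwrote the batchnorm layer's mode in the first place.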
