The newly introduced torch.fx package is quite useful for pruning and neural architecture search. However, I have run into problems resuming training when a job is stopped unexpectedly, because the directed graph of the network has already been changed. For example, a convolution layer that was removed from the graph should be bypassed in the resumed job, but that layer is still referenced in nn.Module.forward() in the Python script. In other words, the model checkpoint saved by torch.save() may no longer match the original model definition.
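To make the mismatch concrete, here is a minimal sketch (the toy Net and layer names are just placeholders I made up for illustration): a conv node is erased from the traced graph, after which the pruned state_dict can no longer be loaded strictly into the original Python class.

```python
import torch
import torch.nn as nn
import torch.fx as fx

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

# Trace the model and delete the conv node, wiring its input
# straight into relu, i.e. the layer is bypassed in the graph.
traced = fx.symbolic_trace(Net())
for node in list(traced.graph.nodes):
    if node.op == "call_module" and node.target == "conv":
        node.replace_all_uses_with(node.args[0])
        traced.graph.erase_node(node)
traced.delete_submodule("conv")
traced.recompile()

# The pruned module runs fine, and its state_dict has no conv.* keys...
sd = traced.state_dict()

# ...so a strict load back into the *original* Python definition fails.
try:
    Net().load_state_dict(sd)
    mismatch = None
except RuntimeError as e:
    mismatch = e
```

Resuming from `sd` with the unmodified class therefore raises a missing-keys error (or silently keeps the stale conv if loaded with strict=False), which is exactly the situation I am asking about.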
The checkpoint does not include the directed-graph information, so you still have to rely on the Python definition, unless you save the model with torch.jit.script(). But TorchScript would no longer support torch.fx transformations, because it seems the graph of a scripted module is not changeable.
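One workaround I have considered, though I am not sure it is the intended approach: a GraphModule is itself picklable, so calling torch.save() on the whole traced module (instead of on its state_dict()) appears to serialize the generated forward() code together with the parameters. A minimal sketch, assuming a toy Net of my own; the file name is arbitrary, and weights_only=False is needed on newer PyTorch versions because this is a full module pickle rather than a plain tensor checkpoint:

```python
import torch
import torch.nn as nn
import torch.fx as fx

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(4, 4)

    def forward(self, x):
        return self.lin(x)

traced = fx.symbolic_trace(Net())

# Save the whole GraphModule: this pickles the generated forward()
# source along with the parameters, not just the tensors.
torch.save(traced, "pruned.pt")

# weights_only=False because the checkpoint is a full pickle of the
# module, which the stricter default loader would reject.
restored = torch.load("pruned.pt", weights_only=False)

x = torch.randn(2, 4)
```

This keeps the transformed graph in the checkpoint, but I do not know whether it is robust across PyTorch versions, which is why I am asking.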
Are there any good suggestions for resuming the training of a network whose graph has already been changed?