I’m currently writing a fairly complex library that exposes various Modules and I have the following requirement in multiple places that I have a hard time satisfying.
I need to be able to define a constant tensor in a Module that is initialized with a user-defined value and shape. I want the tensor to move to the correct target hardware when calling model.to(some_device), but I don’t want the value to be part of the state_dict when the model is saved/loaded, so that if the user passes a different value at init time, that value is kept and not overwritten when a state_dict is loaded.
There are possible ways around this, e.g. define a Buffer, manually exclude it from the state_dict when the model is saved, and then reload with strict=False. My issue with this is that I would basically pass the problem on to the library’s user, and requiring strict=False would also expose me to all sorts of other weaknesses, since genuinely missing or unexpected keys would no longer raise an error. Instead, I want to be able to manage the problem from within the library, i.e. from the Module definition.
My only idea at the moment is to use a Buffer and override the state_dict saving/loading methods of the Module, but I have never seen it done anywhere. Any sort of pointer would be highly appreciated.
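To make the idea concrete, here is a minimal sketch of what I have in mind. The module name ConstScale and the buffer name const are hypothetical, made up for illustration. It assumes that overriding the underscore-prefixed methods _save_to_state_dict and _load_from_state_dict of nn.Module is acceptable, even though, as far as I can tell, they are not part of the documented public API.

```python
import torch
from torch import nn


class ConstScale(nn.Module):
    """Hypothetical module that multiplies its input by a user-defined constant."""

    def __init__(self, value, shape):
        super().__init__()
        # A buffer follows the module through .to(device) / .cuda().
        self.register_buffer("const", torch.full(shape, float(value)))

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        # Let the default implementation run, then drop the constant
        # so it never ends up in the saved state_dict.
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination.pop(prefix + "const", None)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        key = prefix + "const"
        # Ignore a constant stored in the checkpoint (e.g. an older file),
        # so the value passed at init time is not overwritten.
        state_dict.pop(key, None)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)
        # Do not report the constant as missing, so strict=True still works.
        if key in missing_keys:
            missing_keys.remove(key)

    def forward(self, x):
        return x * self.const


saved = ConstScale(2.0, (3,))
checkpoint = saved.state_dict()       # does not contain "const"

model = ConstScale(5.0, (3,))
model.load_state_dict(checkpoint, strict=True)
```

The intent is that a checkpoint saved with one value can be loaded with strict=True into a module built with another value, and the value passed at init time survives.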
The implementation should work with the JIT compiler in a DataParallel context.
Alessandro