Stateless constant tensor w/ DataParallel, JIT

I’m currently writing a fairly complex library that exposes various Modules, and in multiple places I have the following requirement that I’m having a hard time satisfying.
I need to be able to define a constant tensor in a Module that is initialized with some user-defined value and shape. I want the tensor to move to the correct target hardware when calling model.to(some_device), but I don’t want its value to be part of the state_dict when the model is saved/loaded, so that if the user passes a different value at init time, that value is used and not overwritten when a state_dict is loaded.
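To make the conflict concrete, here is a minimal sketch (ScaledAdd is a made-up example, not one of the library’s modules): a plain buffer follows the model across devices, which is what I want, but it also ends up in the state_dict, which is what I don’t want.

```python
import torch
import torch.nn as nn

class ScaledAdd(nn.Module):  # hypothetical example module
    def __init__(self, offset):
        super().__init__()
        # A buffer follows model.to(some_device), which is the behaviour I want...
        self.register_buffer('offset', torch.tensor(float(offset)))

    def forward(self, x):
        return x + self.offset

m = ScaledAdd(2.0)
# ...but the buffer is also serialized, so loading a checkpoint
# overwrites whatever value was passed at init time.
print(m.state_dict())  # OrderedDict([('offset', tensor(2.))])
```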
There are possible ways around this, e.g. defining a buffer, manually excluding it from the state_dict when the model is saved, and then reloading with strict=False. My issue with this is that I would basically be passing the problem on to the library’s users, and requiring strict=False would also expose me to all sorts of other weaknesses. Instead, I want to manage the problem from within the library, i.e. from the Module definition.
My only idea at the moment is to use a buffer and override the Module’s state_dict loading APIs, but I have never seen it done anywhere. Any sort of pointer would be highly appreciated.
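Something along these lines is what I have in mind. This is a rough, untested sketch that hooks into the private _load_from_state_dict method, so it may break across PyTorch versions:

```python
import torch
import torch.nn as nn

class ConstModule(nn.Module):  # hypothetical name, just a sketch
    def __init__(self, value):
        super().__init__()
        self.register_buffer('value', torch.tensor(float(value)))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys,
                              error_msgs):
        # Force the incoming state_dict to hold the current (init-time)
        # value under the expected key: the loader then neither overwrites
        # the buffer nor reports it as missing.
        state_dict[prefix + 'value'] = self.value
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys,
                                      error_msgs)

    def forward(self, x):
        return x + self.value
```

Note that the buffer would still be written out on the saving side, so state_dict would presumably need a similar override to filter the key out.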
The implementation should work with the JIT compiler in a DataParallel context.

Alessandro

To answer myself: I partially solved the problem by splitting the value into a buffer with value 0 and a Python constant holding the user-defined value, and then summing them in the forward pass. Hopefully the JIT is smart enough to see that both are constant and propagate them when required, but I haven’t verified it yet.
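In code, the workaround looks roughly like this (the class name is mine, and as said I haven’t verified what the JIT actually does with it):

```python
import torch
import torch.nn as nn

class StatelessConst(nn.Module):  # hypothetical name for the sketch
    def __init__(self, value):
        super().__init__()
        # The user-defined value lives in a plain Python attribute, so it
        # never enters the state_dict.
        self.value = float(value)
        # The zero buffer carries no information, but it follows the module
        # through model.to(some_device) and supplies the target device/dtype.
        self.register_buffer('zero', torch.zeros(1))

    def forward(self, x):
        # zero + value materializes the constant on the right device; both
        # operands are constant, so the JIT can hopefully fold the sum.
        return x + (self.zero + self.value)
```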
There is still a requirement on strict=False though, since loading a pretrained state_dict into a model to which any of these layers have been added will complain about the missing buffer.
It would be nice to see a proper API for this use case, especially given that constant values could easily be optimized by the JIT.

Alessandro