DataParallel with stateful model

I have an RNN cell where I keep the activations as model state. I realize this goes somewhat against normal PyTorch conventions, so maybe I shouldn't be doing it, but it has been working fine, and I think it has advantages in the context of my larger model (though they might be illusory or debatable).
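For context, the cell looks roughly like this. It's a stripped-down sketch rather than my actual code (written with plain tensors); the point is just that the hidden activation lives as an attribute on the module and gets updated in forward:

```python
import torch
import torch.nn as nn

class StatefulRNNCell(nn.Module):
    """Sketch of an RNN cell that keeps its activations on the module itself."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.hidden = None  # persistent activation, set in forward

    def forward(self, x):
        if self.hidden is None:
            # lazily create the state with the right batch size / device
            self.hidden = x.new_zeros(x.size(0), self.h2h.out_features)
        self.hidden = torch.tanh(self.i2h(x) + self.h2h(self.hidden))
        return self.hidden
```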

I recently tried to use DataParallel, and it didn't work; the activations (which are just Variables set as attributes on the model) do not persist between calls to forward, presumably because DataParallel replicates the module for each call and the replicas are discarded afterwards.
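A minimal reproduction of what I'm seeing, using the sketch above (on a multi-GPU machine):

```python
cell = StatefulRNNCell(8, 16).cuda()
parallel = nn.DataParallel(cell)

x = torch.randn(4, 8).cuda()
parallel(x)

# With more than one GPU, DataParallel replicates the module for each
# forward call, so self.hidden only ever gets set on the throwaway
# replicas, never on the original module.
print(cell.hidden)  # -> None
```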

Maybe modules should support something like a register_device_variable method that would allow Variables to persist between calls to forward (but not across checkpoints)?

But in the meantime, is there a good workaround? I thought about keeping a static array of the Variables, having forward look up the right one from the device of one of its parameters, and then storing the new Variable back before returning, but I haven't tried it yet.
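Something like this is what I had in mind, as an untested sketch; the class-level dict plays the role of the "static array", and the device comes from one of the module's parameters:

```python
class DeviceStatefulCell(nn.Module):
    # Hidden states for all replicas, keyed by device; a class attribute is
    # shared between the replicas that DataParallel creates, so each replica
    # finds (and updates) its own entry.
    _hidden_by_device = {}

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.i2h = nn.Linear(input_size, hidden_size)
        self.h2h = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        device = next(self.parameters()).device  # which replica am I?
        hidden = DeviceStatefulCell._hidden_by_device.get(device)
        if hidden is None:
            hidden = x.new_zeros(x.size(0), self.h2h.out_features)
        hidden = torch.tanh(self.i2h(x) + self.h2h(hidden))
        # store the new state back before returning
        DeviceStatefulCell._hidden_by_device[device] = hidden
        return hidden
```

(Keying only on the device obviously breaks if the model contains more than one of these cells, so the key would probably need some per-cell name as well.)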