Overriding Device Transfer Behavior

I have a model consisting of a shared body and multiple task-specific heads. For each batch, I run the forward and backward pass through the shared body and one of the task-specific heads.

Keeping all task-specific heads in GPU memory is not possible on the GPUs available to me, so I want to dynamically load the required task-specific head for each batch into GPU memory and unload the previous one. To achieve this, I plan two modifications to my existing (rather long) code:

  1. Override to() such that it only transfers the shared body to GPU memory.
  2. Add some to() calls in forward() that load the required task-specific head into a local variable temporarily, so that it is garbage collected at the end of the optimizer step.
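To make the question concrete, here is a minimal sketch of what I have in mind. All names, dimensions, and the ModuleDict layout are placeholders, not my actual code:

```python
import torch
import torch.nn as nn


class MultiHeadModel(nn.Module):
    """Hypothetical sketch: shared body follows device transfers, heads do not."""

    def __init__(self, num_tasks: int, in_dim: int = 8, hidden: int = 32, out_dim: int = 4):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        # ModuleDict keys must be strings.
        self.heads = nn.ModuleDict(
            {str(t): nn.Linear(hidden, out_dim) for t in range(num_tasks)}
        )
        self._active = None  # task whose head currently sits on the compute device

    def to(self, *args, **kwargs):
        # Modification 1: only the shared body follows device/dtype
        # transfers; the heads stay wherever they currently are.
        self.body.to(*args, **kwargs)
        return self

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        # Modification 2: load the required head, unload the previous one.
        # Module.to() moves parameters in place, so the head is not freed
        # when the local variable dies; the previous head is moved back
        # to CPU explicitly instead.
        if self._active is not None and self._active != task:
            self.heads[self._active].to("cpu")
        self._active = task
        head = self.heads[task].to(x.device)
        return head(self.body(x))
```

(Handling of optimizer state that may still reference the unloaded head's parameters is left out of the sketch.)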

My question: is it sufficient to override to(), or do I also have to override cuda() and cpu()? I noticed that some logic is hardcoded in Module.to(), so I assume I would need to replicate that in my override, right?

Thank you for your support!