I have a model that consists of a shared body and multiple task-specific heads. For each batch, I run a forward and backward pass through the shared body and one of the task-specific heads.
Keeping all task-specific heads in GPU memory at once is not possible on the GPUs available to me, so I want to dynamically load the required task-specific head into GPU memory for each batch and unload the previous one. To that end, I plan two modifications to my existing (rather long) code:
- Override `to()` such that it only transfers the shared body to GPU memory.
- Add some `to()` calls in `forward()` to load the required task-specific head into a local variable temporarily, such that it is garbage collected at the end of the optimizer step.
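To make the plan concrete, here is a minimal sketch of what I have in mind (all names like `MultiTaskModel`, `body`, and `heads` are illustrative, not my actual code). One thing I noticed while sketching it: `nn.Module.to()` operates in place and returns `self`, so the head stored in the module moves too, and "unloading" seems to require an explicit move back to the CPU rather than relying on the local variable being garbage collected.

```python
import torch
import torch.nn as nn

class MultiTaskModel(nn.Module):
    """Illustrative sketch of the two planned modifications."""

    def __init__(self, num_tasks: int, in_dim: int = 8,
                 hidden: int = 16, out_dim: int = 4):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.heads = nn.ModuleDict(
            {str(t): nn.Linear(hidden, out_dim) for t in range(num_tasks)}
        )

    def to(self, *args, **kwargs):
        # Modification 1: only the shared body follows the requested
        # device/dtype; the task-specific heads are skipped on purpose.
        self.body.to(*args, **kwargs)
        return self

    def forward(self, x: torch.Tensor, task: int) -> torch.Tensor:
        features = self.body(x)
        # Modification 2: move the requested head to the input's device
        # for this batch. Caveat: nn.Module.to() is in-place, so the
        # head stored in self.heads moves as well; unloading it means
        # calling .to("cpu") on it after the optimizer step, not just
        # letting this local variable go out of scope.
        head = self.heads[str(task)].to(x.device)
        return head(features)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = MultiTaskModel(num_tasks=3).to(device)  # moves only the body
x = torch.randn(2, 8, device=device)
y = model(x, task=1)
model.heads["1"].to("cpu")                      # explicit unload
print(tuple(y.shape))  # (2, 4)
```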
My question is: Is it sufficient to override `to()`, or do I also have to override `cuda()` and `cpu()`? I noticed that there is some logic hardcoded in `Module.to()`, so I think I need to copy that, right?
Thank you for your support!