The model and the input need to be on the same device. If you have multiple GPUs, one thing you could do is use model parallelism in PyTorch: distribute your model's layers across the GPUs, keeping only a few layers (or even just one) on the GPU that holds your input.
Might not be what you wanted, but it may help. The model parallel tutorial is here -
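A minimal sketch of the idea (the module name and layer sizes are made up for illustration): each layer lives on its own device, and `forward()` moves the activations between them. It falls back to CPU so it also runs on a machine without two GPUs.

```python
import torch
import torch.nn as nn

# Pick two devices; fall back to CPU when fewer than two GPUs are present.
dev0 = torch.device("cuda:0" if torch.cuda.device_count() >= 1 else "cpu")
dev1 = torch.device("cuda:1" if torch.cuda.device_count() >= 2 else "cpu")

class TwoDeviceNet(nn.Module):
    """Hypothetical two-layer net split across two devices."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 64).to(dev0)  # bulk of the model on dev0
        self.fc2 = nn.Linear(64, 10).to(dev1)   # small head on the input's GPU

    def forward(self, x):
        x = torch.relu(self.fc1(x.to(dev0)))
        return self.fc2(x.to(dev1))             # move activations between devices

model = TwoDeviceNet()
out = model(torch.randn(4, 128))                # input tensor starts on CPU
print(out.shape)
```

The explicit `.to(devX)` calls on the activations are what make this model parallelism rather than a single-device model: each layer only ever sees tensors on its own device.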
Try using data parallelism. It’s easy to use if you have access to multiple GPUs. Make sure each per-GPU slice of the batch is still small enough to fit in a single GPU’s memory.
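A hedged sketch using `torch.nn.DataParallel` (layer sizes are arbitrary): the wrapper replicates the module on each visible GPU and splits the batch along dimension 0, so each replica processes a slice. On a CPU-only machine the wrapped module just runs unchanged, keeping the example runnable.

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 8)                    # stand-in for your real model
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model).cuda()   # scatter inputs, gather outputs

batch = torch.randn(16, 32)                 # split into per-GPU chunks along dim 0
if torch.cuda.is_available():
    batch = batch.cuda()

out = model(batch)                          # output gathered back on one device
print(out.shape)
```

Note that with this approach the whole model must fit on every GPU, which is the opposite trade-off from the model-parallel route: data parallelism scales the batch, model parallelism scales the model.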