Depending on the model architecture, you could try to apply model sharding as shown in this example.
This approach stores submodules on different devices and, in the forward method, transfers each submodule's output to the device holding the next submodule.
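A minimal sketch of the idea, assuming a toy model whose layer sizes and the names `ShardedModel`, `part1`, `part2` and `pick_devices` are made up for illustration. The first half of the model is placed on one device and the second half on another. The intermediate activation is moved between them in `forward`. It falls back to the CPU when fewer than two GPUs are visible, so it also runs without multiple GPUs.

```python
import torch
import torch.nn as nn


class ShardedModel(nn.Module):
    """Hypothetical two-stage model split across two devices."""

    def __init__(self, dev0, dev1):
        super().__init__()
        self.dev0 = torch.device(dev0)
        self.dev1 = torch.device(dev1)
        # First submodule lives on dev0 (layer sizes are arbitrary)
        self.part1 = nn.Sequential(
            nn.Linear(1024, 4096),
            nn.ReLU(),
        ).to(self.dev0)
        # Second submodule lives on dev1
        self.part2 = nn.Sequential(
            nn.Linear(4096, 10),
        ).to(self.dev1)

    def forward(self, x):
        # Run the first submodule on its own device
        x = self.part1(x.to(self.dev0))
        # Move the intermediate activation to the device of the next submodule
        x = self.part2(x.to(self.dev1))
        return x


def pick_devices():
    # Use two GPUs if available, otherwise fall back to CPU for both stages
    if torch.cuda.device_count() >= 2:
        return "cuda:0", "cuda:1"
    return "cpu", "cpu"


dev0, dev1 = pick_devices()
model = ShardedModel(dev0, dev1)
out = model(torch.randn(8, 1024))
# The output (and therefore the loss) lives on dev1
loss = out.sum()
# Autograd tracks the cross-device transfer, so backward works as usual
loss.backward()
print(out.shape, out.device)
```

Note that the targets used in the loss calculation would also need to be on the last device. With this naive approach only one device is busy at a time. It mainly helps to fit a model that is too large for a single GPU, not to speed it up.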
Would this be possible, or does a single layer already cause the OOM issue?