The model is too big to run on a single GPU even with a batch size of 1.

Depending on the model architecture, you could try to apply model sharding as shown in this example.
This approach stores submodules on different devices and moves the intermediate activations to the corresponding device inside the forward method.
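
A minimal sketch of this idea, assuming two GPUs (`cuda:0` and `cuda:1`) and a toy two-part model (the layer sizes and module names are placeholders, not your actual architecture):

```python
import torch
import torch.nn as nn

class ShardedModel(nn.Module):
    def __init__(self):
        super().__init__()
        # First shard lives on the first GPU ...
        self.part1 = nn.Sequential(
            nn.Linear(1024, 4096),
            nn.ReLU(),
        ).to("cuda:0")
        # ... second shard lives on the second GPU.
        self.part2 = nn.Linear(4096, 10).to("cuda:1")

    def forward(self, x):
        # Run the first shard on cuda:0.
        x = self.part1(x.to("cuda:0"))
        # Move the intermediate activation to cuda:1 before the second shard.
        x = self.part2(x.to("cuda:1"))
        return x

model = ShardedModel()
out = model(torch.randn(8, 1024))  # output tensor ends up on cuda:1
```

Note that this naive sharding keeps only one GPU busy at a time; it trades speed for the ability to fit the model at all.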

Would this be possible, or is a single layer already causing the OOM issue?
