If the crash is due to OOM on the CPU side during model loading rather than during inference, I would check whether, for example, increasing the swap size helps while loading.
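Besides swap, it can also help to keep the CPU-side peak down during loading itself. A minimal sketch (the checkpoint path and `nn.Linear` model are just illustrative): `torch.load` materializes a full second copy of the weights in CPU RAM, so dropping the `state_dict` reference as soon as the weights are assigned keeps the peak closer to one copy.

```python
import gc
import os
import tempfile

import torch
import torch.nn as nn

# Save a small checkpoint so the loading path below is reproducible.
ckpt_path = os.path.join(tempfile.gettempdir(), "demo_ckpt.pt")
model = nn.Linear(16, 16)
torch.save(model.state_dict(), ckpt_path)

# torch.load creates a full copy of the weights in CPU RAM, so the peak
# is roughly model + state_dict while both are alive.
fresh = nn.Linear(16, 16)
state = torch.load(ckpt_path, map_location="cpu")
fresh.load_state_dict(state)

# Drop the CPU-side state_dict as soon as the weights are assigned.
del state
gc.collect()
```

After this, moving the model to the GPU with `model.to("cuda")` is a separate step and does not need the extra CPU copy anymore.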
One issue I have run into is that even if you move tensors to the GPU, they can still take up system memory. If that is the problem, streaming the data instead of loading it all at once could help: torch.utils.data — PyTorch 2.0 documentation
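The streaming idea can be sketched with an `IterableDataset` that yields one sample at a time instead of holding the whole dataset in RAM (the dataset class and sizes here are made up for illustration):

```python
import torch
from torch.utils.data import IterableDataset, DataLoader


class StreamingDataset(IterableDataset):
    """Yields samples lazily, one at a time, so the full dataset
    never has to sit in system memory at once."""

    def __init__(self, num_samples: int):
        self.num_samples = num_samples

    def __iter__(self):
        for i in range(self.num_samples):
            # In practice this would read one record from disk
            # (e.g. one line of a file, one row of a memory-mapped array).
            yield torch.full((4,), float(i))


loader = DataLoader(StreamingDataset(8), batch_size=2)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 4])
```

Each batch is materialized only when the loop asks for it, so peak host memory stays proportional to the batch size rather than the dataset size.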