Breaking down embeddings across multiple GPUs

I am working on a RecSys problem, and I want to split the Item and User embeddings across multiple GPUs on one machine. I would like to avoid TorchRec (I fundamentally don't get it) and build this from scratch in PyTorch. I would appreciate it if anyone could point me to a good design using PyTorch.
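
To make the question concrete, here is a minimal sketch of the simplest layout I can think of: the whole user table on one device and the whole item table on another, with only the looked-up rows moved between devices. The class name `TwoDeviceRecModel`, the helper `pick_devices`, the table sizes, and the dot-product score are all made up for illustration, and the code falls back to a single GPU or the CPU when two GPUs are not available. It is a starting point under those assumptions, not a recommended design.

```python
import torch
import torch.nn as nn


def pick_devices():
    # Hypothetical helper: use two GPUs when available,
    # otherwise fall back so the sketch still runs.
    n = torch.cuda.device_count()
    if n >= 2:
        return torch.device("cuda:0"), torch.device("cuda:1")
    if n == 1:
        return torch.device("cuda:0"), torch.device("cuda:0")
    return torch.device("cpu"), torch.device("cpu")


class TwoDeviceRecModel(nn.Module):
    """Hypothetical model: user table on one device, item table on another."""

    def __init__(self, num_users, num_items, dim, user_dev, item_dev):
        super().__init__()
        self.user_dev, self.item_dev = user_dev, item_dev
        # Each embedding table lives entirely on its own device.
        self.user_emb = nn.Embedding(num_users, dim).to(user_dev)
        self.item_emb = nn.Embedding(num_items, dim).to(item_dev)

    def forward(self, user_ids, item_ids):
        # Send the ids to whichever device holds the matching table.
        u = self.user_emb(user_ids.to(self.user_dev))
        v = self.item_emb(item_ids.to(self.item_dev))
        # Move only the looked-up rows (not the table) and score
        # each (user, item) pair with a dot product.
        return (u.to(self.item_dev) * v).sum(dim=-1)


user_dev, item_dev = pick_devices()
model = TwoDeviceRecModel(1000, 500, 16, user_dev, item_dev)

users = torch.randint(0, 1000, (8,))
items = torch.randint(0, 500, (8,))
scores = model(users, items)

# Autograd follows the cross-device copy, so each table gets its own gradient.
scores.sum().backward()
```

What this does not cover is the case where a single table is too large for one GPU and would itself have to be split by rows across devices, which is the part I am least sure how to design.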