Yes, you could use model sharding and push different parts of the model (submodules and their parameters) to specific devices via `to('cuda:id')`, then move the activations to the matching device inside the `forward` method.
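A minimal sketch of this approach (the module names, layer sizes, and device strings are just placeholders):

```python
import torch
import torch.nn as nn

class ShardedModel(nn.Module):
    # Model-parallel sketch: each submodule lives on its own device,
    # and forward() moves the activation between them.
    def __init__(self, dev0="cuda:0", dev1="cuda:1"):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        self.part1 = nn.Linear(10, 10).to(dev0)  # first shard on device 0
        self.part2 = nn.Linear(10, 2).to(dev1)   # second shard on device 1

    def forward(self, x):
        x = self.part1(x.to(self.dev0))  # compute on device 0
        x = self.part2(x.to(self.dev1))  # move activation to device 1
        return x

# usage (needs 2 GPUs; pass dev0="cpu", dev1="cpu" to try it without them)
if torch.cuda.device_count() >= 2:
    model = ShardedModel()
    out = model(torch.randn(4, 10))  # output lives on cuda:1
```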
PyTorch Lightning provides a beta of sharded training, which might be interesting for you.
CC @williamFalcon for more information