Hello!
I want to train a model on multiple GPUs. For that I am using Lightning, since its API makes this easier. However, I am using a Merlin dataloader module as the datamodule for the Lightning trainer, so I call the trainer like this: trainer.fit(model=model, datamodule=Merlin_module).
In my Merlin module (Merlin_module), each GPU should access a different part of the dataset, which is determined by the names of the files it retrieves. To get the correct files, the datamodule needs to know the device of the process that is reading them. So the question is: how can I know, within the datamodule, which GPU is executing it?
My datamodule has a train_dataloader function that builds the dataset and wraps it in a Merlin Loader. Ideally, I would like to do something like:
def train_dataloader(self):
    rank = get_rank()         # get the rank of the device
    data = get_dataset(rank)  # get the corresponding files for that device
    return Loader(data)
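To make the file-selection part concrete, here is a minimal sketch of what get_dataset would do once it has the rank. The helper name files_for_rank and the naming scheme part_{rank}_*.parquet are only assumptions for illustration; my real files and the Merlin dataset construction are not shown here.

```python
from pathlib import Path


def files_for_rank(data_dir, rank, pattern="part_{rank}_*.parquet"):
    """Return the sorted files belonging to one rank.

    Assumes a hypothetical naming scheme in which the rank is embedded
    in the file name, e.g. part_0_a.parquet for rank 0.
    """
    # The underscore after the rank keeps rank 1 from matching part_10_*.
    return sorted(Path(data_dir).glob(pattern.format(rank=rank)))
```

The returned paths would then be passed to whatever builds the Merlin dataset for that process, so the only missing piece is the rank itself.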
However, I don't know how to access the rank within the datamodule. So my question is: can that be done? If so, how?
Maybe the datamodule is only called in the main process, in which case the problem is even harder to solve.
Thanks a lot!!