@pritamdamania87 thank you very much for your reply. I believe I have implemented your suggestion — would you mind checking it to make sure it makes sense?
```python
import os
import torch
from torch import nn
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import pandas as pd

class CosSimNetwork(nn.Module):
    def __init__(self):
        super(CosSimNetwork, self).__init__()

    def forward(self, embeds):
        cos_sims = 12  # placeholder until I write the real cosine-similarity logic
        return cos_sims

def determine_subranges(fullrange: tuple, num_subranges: int):
    subranges = []
    inc = fullrange[1] // num_subranges
    for i in range(fullrange[0], fullrange[1], inc):
        subranges.append((i, min(i + inc, fullrange[1])))
    return subranges

def calc_cos_sims(rank, world_size):
    # init_process_group needs a rendezvous point; env:// is the default init method
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    model = CosSimNetwork().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    tmp_df = pd.read_pickle('./embed_pairs_df_million.pkl')
    sub_ranges = determine_subranges((0, tmp_df.shape[0]), world_size)
    sub_range_tuple = sub_ranges[rank]
    data = tmp_df.iloc[sub_range_tuple[0]:sub_range_tuple[1]]
    # a DataFrame has no .to(); convert to a tensor before moving it to the GPU
    embeds = torch.tensor(data.values).to(rank)
    cos_sims = ddp_model(embeds)

def main():
    world_size = 4  # since I have 4 GPUs on a single machine
    mp.spawn(calc_cos_sims,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == '__main__':  # note: '__main__', not 'main'
    main()
```
If I understand this correctly, `mp.spawn()` will create 4 different processes (based on `world_size`) that will run `calc_cos_sims()`. The input data is sliced into `world_size` chunks, and the `rank` determines which chunk gets sent to which GPU. Once all the processes are finished, `mp.spawn()` will aggregate the results. Is my understanding correct?
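As a sanity check on my slicing helper (copied from the snippet above so it runs standalone):

```python
def determine_subranges(fullrange, num_subranges):
    subranges = []
    inc = fullrange[1] // num_subranges
    for i in range(fullrange[0], fullrange[1], inc):
        subranges.append((i, min(i + inc, fullrange[1])))
    return subranges

# 1,000,000 rows over 4 ranks -> four equal, contiguous chunks
print(determine_subranges((0, 1_000_000), 4))
# [(0, 250000), (250000, 500000), (500000, 750000), (750000, 1000000)]

# Caveat I noticed while testing: with an uneven split the helper emits an
# extra 5th-style chunk, which no rank would ever process in my setup.
print(determine_subranges((0, 10), 3))
# [(0, 3), (3, 6), (6, 9), (9, 10)]
```

So for row counts not divisible by `world_size`, the trailing rows would silently be skipped — something I'd need to handle.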
A few other questions come to mind:

Is the entire dataset read in 4 times by the `tmp_df = pd.read_pickle('./embed_pairs_df_million.pkl')` line? If so, is there a more efficient way to read it in for this setup?

The `forward()` method returns the cosine similarity (or it will once I write it) between two embeddings. If `calc_cos_sims()` is copied to each process, would I need to replace the `mp.spawn()` line with `all_cos_sims = mp.spawn()` in order to store the results from all the GPUs?
Thanks in advance for your help!