Basically I have one GPU handling the following computations:
S, = self.temporal_sources()
A, B = self.spatial_sources()
Until I start using 2, as I would otherwise go out of memory:
chunk_S = torch.cuda.comm.scatter(S, range(self.n_gpu))  # split S along dim 0, one chunk per GPU
broadcast_A = torch.cuda.comm.broadcast(A, range(self.n_gpu))  # copy A onto every GPU
chunk_SA = [torch.mm(Si,Ai) for Si, Ai in zip(chunk_S, broadcast_A)]
broadcast_B = torch.cuda.comm.broadcast(B, range(self.n_gpu))
ll_list = [self.likelihood((SiA, target, Bi)) for SiA, target, Bi in zip(chunk_SA, self.target, broadcast_B)]
ll = torch.cuda.comm.reduce_add(ll_list)
Note that the target in the loop comes from self.target, which is a list containing the scattered target data.
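For reference, a minimal sketch of how such a scattered target list can be built (the shapes here are made-up placeholders, not my real data):

import torch

n_gpu = torch.cuda.device_count()
target = torch.randn(1000, 50)  # made-up shape: 1000 samples, 50 channels
# Scatter along dim 0 so that chunk i lives on cuda:i, lining up with chunk_S[i] above.
scattered_target = list(torch.cuda.comm.scatter(target, range(n_gpu)))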
If I do it this way, it doesn't work.
But if I do the following, it works fine:
chunk_S = [Si.to('cuda:' + str(gpu_id)) for gpu_id, Si in enumerate(S.chunk(self.n_gpu, 0))]  # manual scatter
chunk_SA = [torch.mm(Si, A.to('cuda:' + str(Si.get_device()))) for Si in chunk_S]  # manual broadcast of A
ll_list = [self.likelihood((SiA, target, B.to('cuda:' + str(SiA.get_device())))) for SiA, target in zip(chunk_SA, self.target)]
ll = sum([l.to('cuda:0') for l in ll_list])  # manual reduce onto cuda:0
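To make that runnable outside my model, here is a toy version of the working pipeline with made-up shapes and a placeholder squared-error likelihood (both the shapes and the likelihood are assumptions, just for illustration):

import torch

n_gpu = torch.cuda.device_count()
S = torch.randn(1000, 20, device='cuda:0', requires_grad=True)  # dummy temporal sources
A = torch.randn(20, 50, device='cuda:0')  # dummy spatial map
targets = [t.to('cuda:' + str(i)) for i, t in enumerate(torch.randn(1000, 50).chunk(n_gpu, 0))]

chunk_S = [Si.to('cuda:' + str(i)) for i, Si in enumerate(S.chunk(n_gpu, 0))]
chunk_SA = [torch.mm(Si, A.to('cuda:' + str(Si.get_device()))) for Si in chunk_S]
# Placeholder likelihood: negative squared error, just to get one scalar per GPU.
ll_list = [-((SiA - t) ** 2).sum() for SiA, t in zip(chunk_SA, targets)]
ll = sum(l.to('cuda:0') for l in ll_list)
ll.backward()  # gradients flow back to S through the .to() copies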
I've been doing many tests with different seeds and parameters, so I'm confident that this was not luck. The model takes a long time to converge and is highly unstable at the beginning. But in the working implementation, this is the pattern of the loss at the beginning:
As you can see it is pretty messy, but in the end it converges.
However, when I use torch.cuda.comm.reduce_add, it goes crazy very fast:
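In case someone wants to check the reduction itself in isolation, this is the kind of standalone comparison I would run (a sketch with random tensors, not my model):

import torch

n_gpu = torch.cuda.device_count()
xs = [torch.randn(4, 4, device='cuda:' + str(i)) for i in range(n_gpu)]
manual = sum(x.to('cuda:0') for x in xs)  # manual reduction, as in my working version
reduced = torch.cuda.comm.reduce_add(xs, destination=0)  # built-in reduction onto device 0
print(torch.allclose(manual, reduced))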
Hope this helps.