How to use Python multiprocessing with CUDA?

Hello everyone, I'm trying to train a model by splitting each layer onto a different GPU (the intended data flow is described right after the code):

import torch
import torch.nn as nn
import multiprocessing as mp

class ToyModel(nn.Module):  # Consider one ToyModel as one layer in the original model
    def __init__(self, x_in, x_out, device):
        super().__init__()
        self.layer = nn.Linear(x_in, x_out).to(device)
        self.device = device
        
    def forward(self, x, y):
        x = self.layer(x.to(self.device))
        # keep the local loss around so it can be backpropagated inside this layer's process
        self.loss = nn.MSELoss()(x, y.to(self.device))
        return x.detach()

def train(arg):
    model, opt, queue_in, queue_out, index = arg
    print(f'{index} layer start training')
    for i in range(10):  # Assuming there are only 10 batches of data
        x, y = queue_in.get()           # block until the previous stage provides a batch
        out = model(x, y)
        queue_out.put((out, y))         # pass the detached output and the targets downstream
        model.loss.backward()
        opt.step()
        opt.zero_grad()
    print(f'{index} layer finish training')
        
if __name__ == '__main__':
    
    mp.set_start_method('spawn')
    model = {}
    opt = {}
    queue = {}
    m = mp.Manager()
    for i in range(3):  # Assuming the original model has three layers
        device = 'cuda:' + str(i)
        model[i] = ToyModel(100, 100, device)
        opt[i] = torch.optim.SGD(model[i].parameters(), lr=0.001)
        queue[i] = m.Queue()
    queue[3] = m.Queue()  # output queue of the last layer
    x = torch.randn(2000, 100)
    y = torch.randn(2000, 100)
    x = iter(x.split(200, dim = 0))
    y = iter(y.split(200, dim = 0))
    
    for splitX, splitY in zip(x,y):
        data = (splitX, splitY)
        queue[0].put(data)
    
    p = mp.Pool(3)
    arg = []
    for i in range(3):
        arg.append([model[i], opt[i], queue[i], queue[i+1], i])    

    p.map(train, arg)
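
To clarify the intended data flow: each batch should travel queue[0] → layer 0 → queue[1] → layer 1 → queue[2] → layer 2 → queue[3], where every layer runs backward on its own local MSE loss and only passes detached activations (together with the targets) to the next stage.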

However, the program gets stuck when it tries to train the second and third layers. Could anyone tell me what might be wrong? Thanks.
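
In case it helps, here is a stripped-down, CPU-only sketch of the same queue hand-off that I would expect to behave the same way (no model or GPUs involved; worker, queues and NUM_BATCHES are just names I made up for this illustration):

import multiprocessing as mp
import torch

NUM_BATCHES = 10

def worker(arg):
    queue_in, queue_out, index = arg
    print(f'{index} worker start')
    for _ in range(NUM_BATCHES):
        x, y = queue_in.get()        # block until the previous stage hands over a batch
        queue_out.put((x + 1, y))    # stand-in for a forward pass, then pass it on
    print(f'{index} worker finish')

if __name__ == '__main__':
    mp.set_start_method('spawn')
    m = mp.Manager()
    queues = [m.Queue() for _ in range(4)]
    for _ in range(NUM_BATCHES):
        queues[0].put((torch.randn(200, 100), torch.randn(200, 100)))
    p = mp.Pool(3)
    p.map(worker, [(queues[i], queues[i + 1], i) for i in range(3)])

If the chaining of Manager queues is fine on its own, then I suspect the issue is in how the models, optimizers, or CUDA tensors are shared between the pool processes, but I can't pin it down.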