Batch mode forward propagation is blocked in subprocess

Hi everyone,

I am trying to implement asynchronous Q-Learning. Each subprocess owns a copy of Deep Q Network, but when performing prediction on batch input, the forward propagation is blocked for no reason.

You can track the original code in following

from torch import nn
import torch
import torch.multiprocessing as mp
import numpy as np
import pdb


class OneHotNGramDQN(nn.Module):
    def __init__(
            self,
            n,
            movie_latent_factor_num,
            layer
    ):
        super(OneHotNGramDQN, self).__init__()
        self._n = n

        self._item_embedding = nn.Embedding(
            10994,
            movie_latent_factor_num
        )

        self._linear_1 = nn.Linear(
            (n + 1) * movie_latent_factor_num,
            layer[0]
        )
        self._linear_2 = nn.Linear(
            layer[0],
            layer[1]
        )
        self._linear_3 = nn.Linear(
            layer[1],
            layer[2]
        )
        self._linear_4 = nn.Linear(
            layer[2],
            layer[3]
        )
        self._linear_5 = nn.Linear(
            layer[3],
            1
        )

        self._relu = nn.ReLU()

    def forward(self, state, action):
        state_x = self._item_embedding(state).view(state.shape[0], -1)
        action_x = self._item_embedding(action).view(state.shape[0], -1)
        x = torch.cat([state_x, action_x], dim=-1)
        x = self._relu(self._linear_1(x))
        x = self._relu(self._linear_2(x))
        x = self._relu(self._linear_3(x))
        x = self._relu(self._linear_4(x))
        x = self._linear_5(x)

        return x


def test_1():
    net = OneHotNGramDQN(10, 32, [640, 320, 160, 50])
    feature_input = torch.LongTensor(np.random.randint(0, 10992, (1, 10)))
    action_input = torch.LongTensor([0])
    print(net(feature_input, action_input).shape)


def test_2():
    net = OneHotNGramDQN(10, 32, [640, 320, 160, 50])
    batch_size = 128
    feature_input = torch.LongTensor(np.random.randint(0, 10992, (batch_size, 10)))
    action_input = torch.LongTensor(np.zeros((batch_size, 1)))
    print(net(feature_input, action_input).shape)


if __name__ == '__main__':
    net = OneHotNGramDQN(10, 32, [640, 320, 160, 50])
    feature_input = torch.LongTensor(np.random.randint(0, 10992, (1, 10)))
    action_input = torch.LongTensor([0])
    print(net(feature_input, action_input).shape)

    batch_size = 128
    feature_input = torch.LongTensor(np.random.randint(0, 10992, (batch_size, 10)))
    action_input = torch.LongTensor(np.zeros((batch_size, 1)))
    print(net(feature_input, action_input).shape)

    # Running without problem
    workers = [mp.Process(target=test_1) for i in range(mp.cpu_count())]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Forward propagation is somehow blocked
    workers = [mp.Process(target=test_2) for i in range(mp.cpu_count())]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

I do not quite understand why it happened, although the single instance prediction seems to be ok.
Could anyone kindly give any clue of it?

Thanks a lot.
BR

Can anyone help? Its really weird…