Hi everyone,
I am trying to implement asynchronous Q-Learning. Each subprocess owns a copy of Deep Q Network, but when performing prediction on batch input, the forward propagation is blocked for no reason.
You can track the original code in following
from torch import nn
import torch
import torch.multiprocessing as mp
import numpy as np
import pdb
class OneHotNGramDQN(nn.Module):
def __init__(
self,
n,
movie_latent_factor_num,
layer
):
super(OneHotNGramDQN, self).__init__()
self._n = n
self._item_embedding = nn.Embedding(
10994,
movie_latent_factor_num
)
self._linear_1 = nn.Linear(
(n + 1) * movie_latent_factor_num,
layer[0]
)
self._linear_2 = nn.Linear(
layer[0],
layer[1]
)
self._linear_3 = nn.Linear(
layer[1],
layer[2]
)
self._linear_4 = nn.Linear(
layer[2],
layer[3]
)
self._linear_5 = nn.Linear(
layer[3],
1
)
self._relu = nn.ReLU()
def forward(self, state, action):
state_x = self._item_embedding(state).view(state.shape[0], -1)
action_x = self._item_embedding(action).view(state.shape[0], -1)
x = torch.cat([state_x, action_x], dim=-1)
x = self._relu(self._linear_1(x))
x = self._relu(self._linear_2(x))
x = self._relu(self._linear_3(x))
x = self._relu(self._linear_4(x))
x = self._linear_5(x)
return x
def test_1():
net = OneHotNGramDQN(10, 32, [640, 320, 160, 50])
feature_input = torch.LongTensor(np.random.randint(0, 10992, (1, 10)))
action_input = torch.LongTensor([0])
print(net(feature_input, action_input).shape)
def test_2():
net = OneHotNGramDQN(10, 32, [640, 320, 160, 50])
batch_size = 128
feature_input = torch.LongTensor(np.random.randint(0, 10992, (batch_size, 10)))
action_input = torch.LongTensor(np.zeros((batch_size, 1)))
print(net(feature_input, action_input).shape)
if __name__ == '__main__':
net = OneHotNGramDQN(10, 32, [640, 320, 160, 50])
feature_input = torch.LongTensor(np.random.randint(0, 10992, (1, 10)))
action_input = torch.LongTensor([0])
print(net(feature_input, action_input).shape)
batch_size = 128
feature_input = torch.LongTensor(np.random.randint(0, 10992, (batch_size, 10)))
action_input = torch.LongTensor(np.zeros((batch_size, 1)))
print(net(feature_input, action_input).shape)
# Running without problem
workers = [mp.Process(target=test_1) for i in range(mp.cpu_count())]
for worker in workers:
worker.start()
for worker in workers:
worker.join()
# Forward propagation is somehow blocked
workers = [mp.Process(target=test_2) for i in range(mp.cpu_count())]
for worker in workers:
worker.start()
for worker in workers:
worker.join()
I do not quite understand why it happened, although the single instance prediction seems to be ok.
Could anyone kindly give any clue of it?
Thanks a lot.
BR