The following example will raise error:
from multiprocessing import set_start_method, Queue, Process
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from random import shuffle
class Model(nn.Module):
def __init__(self, obs_n, act_n):
super(Model, self).__init__()
self.fcs = nn.Sequential(
nn.Linear(obs_n, 32),
nn.ReLU(True),
nn.Linear(32, 16),
nn.ReLU(True),
nn.Linear(16, act_n),
)
def forward(self, x):
return self.fcs(x)
class iData:
def __init__(self, data, batch_size):
self.data = data
self.batch_size = batch_size
def __len__(self):
return self.data.shape[0]
def __iter__(self):
self.index_set = np.arange(self.data.shape[0])
np.random.shuffle(self.index_set)
return self
def __next__(self):
assert len(self.index_set) > 0
if len(self.index_set) < self.batch_size:
raise StopIteration
else:
indices = self.index_set[:self.batch_size]
self.index_set = self.index_set[self.batch_size:]
return torch.from_numpy(self.data.take(indices, axis=0))
def run_model(model, q):
model.cuda()
x = torch.rand(32, 3).cuda()
y = model(x)
q.put(y.detach().cpu().numpy()[0])
def main():
model = Model(3, 3)
print('Training...')
model.cuda()
data = np.random.rand(1000, 3)
dataset = iData(data, 32)
batch_itr = iter(dataset)
for step in range(1000):
try:
x = next(batch_itr)
except:
batch_itr = iter(dataset)
x = next(batch_itr)
y = model(x.float().cuda())
if step % 250 == 0:
print(y.detach().cpu().numpy()[0])
print('----')
model.cpu()
queue = Queue(maxsize=20)
ps = []
for _ in range(4):
xmodel = Model(3, 3)
xmodel.load_state_dict(model.state_dict())
print(list(xmodel.fcs[0].named_parameters()))
p = Process(target=run_model, args=(xmodel, queue))
p.daemon = True
p.start()
ps.append(p)
for _ in range(4):
print(queue.get())
for p in ps:
p.join()
print('---')
main()