I am trying to pass a set of GPU models to a pool of worker processes, as in the example code below (torch version 1.0.0.dev20181015):
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)
import torch
import torch.nn as nn
import time
def run(inpt):
    """Run one inference pass on a model checked out from the shared pool.

    Args:
        inpt: A ``(workers, worker_queue)`` pair. ``workers`` is a sequence
            of callable models indexed by worker id; ``worker_queue`` holds
            the ids of the models that are currently free.

    Returns:
        The output of the selected model for a random ``(5, 10)`` batch that
        was moved to the GPU with ``.cuda()``.
    """
    workers, worker_queue = inpt
    # get() already blocks until an id is available, so polling empty() with
    # a sleep is unnecessary (and empty() can be stale by the time get() runs).
    worker_id = worker_queue.get()
    try:
        model = workers[worker_id]
        data = torch.rand(5, 10).cuda()
        with torch.no_grad():
            return model(data)
    finally:
        # Always hand the id back, even if inference raised; otherwise the
        # model is lost to the pool and later jobs can block forever.
        worker_queue.put(worker_id)
def main(num_workers=10, num_jobs=20):
    """Build a set of GPU models and run inference jobs in a process pool.

    Args:
        num_workers: Number of models to create; also the number of pool
            processes.
        num_jobs: Number of inference jobs submitted to the pool.
    """
    # Context managers make sure the manager server process and the pool
    # workers are shut down instead of being leaked when main() returns.
    with mp.Manager() as manager:
        worker_queue = manager.Queue()  # mp.Queue()
        workers = []
        for worker_id in range(num_workers):
            # NOTE(review): BatchNorm carries an int64 `num_batches_tracked`
            # buffer; the reported rebuild error seems tied to sharing that
            # CUDA tensor on this nightly build -- confirm on a stable release.
            worker_model = nn.BatchNorm1d(10)
            worker_model.cuda()
            worker_model.eval()
            workers.append(worker_model)
            worker_queue.put(worker_id)

        # Every job gets the full model list plus the queue of free model ids.
        jobs = [(workers, worker_queue) for _ in range(num_jobs)]
        with mp.Pool(num_workers) as pool:
            result = list(pool.imap_unordered(run, jobs))
            # Print while the workers are still alive: the results are CUDA
            # tensors whose memory is owned by the worker processes.
            print(result)
# Required with the 'spawn' start method: child processes import this module,
# and the guard keeps them from running main() again.
if __name__ == "__main__":
    main()
However, when the models contain a BatchNorm layer, the following error is raised in torch._utils.rebuild_cuda_tensor:
RuntimeError: Expected object of data type 4 but got data type 6 for argument #2 'source'
It is probably caused by the attribute bn.num_batches_tracked, which is supposed to be data type 4 (int64_t), whereas the error message reports data type 6 (float). What is the likely cause of this issue, and is there a way to fix it?