I am trying to pass a set of GPU models to a pool of processes, along the lines of the example code below:
import itertools
import time  # retained from the original script; run() no longer polls with sleep

import torch
import torch.multiprocessing as mp
import torch.nn as nn

# The CUDA runtime cannot be used in forked children, so worker processes
# must be started with "spawn" (see the PyTorch multiprocessing notes).
mp.set_start_method('spawn', force=True)


def _module_device(module):
    """Return the device of the module's first parameter or buffer.

    Falls back to CUDA for a module that holds no tensors, which matches the
    original script's unconditional ``data.cuda()``.
    """
    tensors = itertools.chain(module.parameters(), module.buffers())
    first = next(tensors, None)
    return first.device if first is not None else torch.device("cuda")


def run(inpt):
    """Run one forward pass on whichever model is currently free.

    Args:
        inpt: A ``(workers, worker_queue)`` pair. ``workers`` is a sequence of
            modules indexed by worker id; ``worker_queue`` holds the ids of
            the models that are not in use right now.

    Returns:
        The model output for a random ``(5, 10)`` batch, copied to the CPU.
    """
    workers, worker_queue = inpt

    # get() blocks until an id is available, so no empty()/sleep polling is
    # needed (and an empty() check could be stale by the time get() runs).
    worker_id = worker_queue.get()
    try:
        model = workers[worker_id]
        data = torch.rand(5, 10).to(_module_device(model))
        with torch.no_grad():
            result = model(data)
        # Hand back a CPU copy: a CUDA tensor returned from a pool worker is
        # shared by handle and relies on the producing process keeping the
        # memory alive after this function has returned.
        return result.cpu()
    finally:
        # Always release the model, even if the forward pass raised;
        # otherwise this id is lost and later jobs wait on the queue forever.
        worker_queue.put(worker_id)


def main(num_workers=10, num_jobs=20):
    """Build the GPU models, fan the jobs out over a pool, print the results.

    Args:
        num_workers: Number of models to create and of pool processes.
        num_jobs: Number of forward passes to schedule.
    """
    workers = []
    # A manager queue is used because a plain mp.Queue() cannot be pickled
    # into Pool task arguments.
    with mp.Manager() as manager:
        worker_queue = manager.Queue()
        for worker_id in range(num_workers):
            worker_model = nn.BatchNorm1d(10)
            worker_model.cuda()
            worker_model.eval()
            workers.append(worker_model)
            worker_queue.put(worker_id)

        # NOTE(review): every job carries the whole `workers` list, so each
        # CUDA module (including BatchNorm's integer `num_batches_tracked`
        # buffer) is re-shared with a worker process for every task. The
        # reported dtype error presumably originates in that transfer --
        # confirm against the installed PyTorch version. Building the models
        # inside the workers from CPU state would avoid it entirely.
        jobs = [(workers, worker_queue) for _ in range(num_jobs)]

        with mp.Pool(num_workers) as pool:
            result = list(pool.imap_unordered(run, jobs))

    print(result)


if __name__ == "__main__":
    main()
However, when a model contains a BatchNorm layer, the following error is raised:
RuntimeError: Expected object of data type 4 but got data type 6 for argument #2 'source'
It is probably caused by the attribute bn.num_batches_tracked, which is supposed to be data type 4 (int64_t), whereas the error message appears to report data type 6 (float). What is the likely cause of this issue, and is there any way to fix it?