torch.multiprocessing: rebuild_cuda_tensor has trouble with bn.num_batches_tracked

I am trying to pass a set of GPU models to a pool of worker processes, along the lines of the example code below (torch version 1.0.0.dev20181015):

import time

import torch
import torch.multiprocessing as mp
import torch.nn as nn

# Using CUDA in subprocesses requires the 'spawn' or 'forkserver' start method.
mp.set_start_method('spawn', force=True)


def run(inpt):
    workers, worker_queue = inpt
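    # Wait until a model id is free, then borrow the corresponding model.
    while worker_queue.empty():
        time.sleep(0.1)
    worker_id = worker_queue.get()
    model = workers[worker_id]
    data = torch.rand(5, 10)
    data = data.cuda()
    with torch.no_grad():
        result = model(data)
    # Hand the model id back so another job can use it.
    worker_queue.put(worker_id)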
    return result


def main():
    workers = []
    worker_queue = mp.Manager().Queue()  # mp.Queue()
    for worker_id in range(10):
        worker_model = nn.BatchNorm1d(10)
        worker_model.cuda()
        worker_model.eval()
        workers.append(worker_model)
        worker_queue.put(worker_id)

    pools = mp.Pool(10)

    # 20 jobs share the 10 models; each job borrows a model id from the queue.
    jobs = [(workers, worker_queue) for _ in range(20)]
    result = []
    for res in pools.imap_unordered(run, jobs):
        result.append(res)
    print(result)


if __name__ == "__main__":
    main()

However, when the model contains a BatchNorm layer, the following error is raised in torch._utils.rebuild_cuda_tensor:

RuntimeError: Expected object of data type 4 but got data type 6 for argument #2 'source'

It is probably caused by the attribute bn.num_batches_tracked, which is supposed to be data type 4 (int64_t), whereas the error message reports data type 6 (float). What is the likely cause of this issue, and is there any way to fix it?
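
To make the numeric codes in the error easier to read, here is a small sketch that translates them into dtype names. The mapping is an assumption on my part: it follows the order of the ATen ScalarType enum (Byte, Char, Short, Int, Long, Half, Float, Double) as I understand it for this build. The helper name decode_dtype_error is hypothetical and not part of PyTorch.

```python
import re

# Assumed mapping from ScalarType codes to dtype names (enum order:
# Byte, Char, Short, Int, Long, Half, Float, Double).
SCALAR_TYPE_NAMES = {
    0: "uint8",
    1: "int8",
    2: "int16",
    3: "int32",
    4: "int64",
    5: "float16",
    6: "float32",
    7: "float64",
}


def decode_dtype_error(message):
    """Hypothetical helper: return (expected, got) dtype names from the error text."""
    match = re.search(r"data type (\d+) but got data type (\d+)", message)
    if match is None:
        raise ValueError("message does not contain two data type codes")
    expected, got = (int(code) for code in match.groups())
    return (
        SCALAR_TYPE_NAMES.get(expected, "unknown"),
        SCALAR_TYPE_NAMES.get(got, "unknown"),
    )


message = "Expected object of data type 4 but got data type 6 for argument #2 'source'"
print(decode_dtype_error(message))
```

Under that assumed mapping, the message says an int64 tensor was expected but a float32 source was supplied, which would be consistent with the int64 num_batches_tracked buffer being rebuilt with the wrong type.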


I posted an issue at https://github.com/pytorch/pytorch/issues/12798