PyTorch multiprocessing DataLoader worker_init_fn problem

I have a problem with this piece of code:

import os

import torch


class Dataset(torch.utils.data.Dataset):
    arg = {'batch_size': 1}
    def __init__(self, arg):
        print('__init__')
        self.arg.update(arg)  # mutates the class attribute `arg` in place
        print(self.arg)

    def _worker_init_fn(self, *args):
        print('worker init')
        print(self.arg)

    def get_dataloader(self):
        return torch.utils.data.DataLoader(self, batch_size=None,
                                         num_workers=3,
                                         worker_init_fn=self._worker_init_fn,
                                         pin_memory=True,
                                         multiprocessing_context='spawn')

    def __getitem__(self, idx):
        return 0

    def __len__(self):
        return 5


def main():
    dataloader = Dataset({'batch_size': 2}).get_dataloader()
    for _ in dataloader:
        pass


if __name__ == '__main__':
    main()

Basically, I want the workers to see {'batch_size': 2}, but they actually get {'batch_size': 1}. The code prints the following:

__init__
{'batch_size': 2}
worker init
{'batch_size': 1}
worker init
{'batch_size': 1}
worker init
{'batch_size': 1}

How can I make the workers see the correct batch_size?

Solved. Adding self.arg = self.arg in __init__, after self.arg.update(arg), makes it work.
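A minimal standard-library sketch of why the self.arg = self.arg workaround helps: pickle, which the spawn start method relies on to send objects to worker processes, serializes an instance's __dict__ but not its class attributes. Demo and roundtrip are hypothetical stand-ins for the Dataset class above, and resetting Demo.arg after pickling only simulates the freshly imported class a spawned worker would see.

```python
import pickle


class Demo:
    # Class attribute, like `arg` in the Dataset above.
    arg = {'batch_size': 1}

    def __init__(self, arg, copy_to_instance):
        # Mutates the dict stored on the class; no instance attribute is created.
        self.arg.update(arg)
        if copy_to_instance:
            # The workaround: put a reference to the dict into the instance __dict__.
            self.arg = self.arg


def roundtrip(copy_to_instance):
    Demo.arg = {'batch_size': 1}          # start from the class default
    obj = Demo({'batch_size': 2}, copy_to_instance)
    payload = pickle.dumps(obj)           # only obj.__dict__ is serialized as state
    Demo.arg = {'batch_size': 1}          # simulate the worker's freshly imported class
    return pickle.loads(payload).arg


print(roundtrip(copy_to_instance=False))  # {'batch_size': 1}
print(roundtrip(copy_to_instance=True))   # {'batch_size': 2}
```

Without the workaround the unpickled object has an empty __dict__ and falls back to the class default, which matches the output in the question.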

Setting multiprocessing_context to 'fork' also works fine on Linux.
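A minimal standard-library sketch of the fork case, assuming a platform where the 'fork' start method exists (Linux; it is not available on Windows). Demo, child and run_in_fork are hypothetical names. The forked child starts from a copy of the parent's memory, so it sees the already-mutated class attribute without the class state being pickled; DataLoader workers started with 'fork' should behave the same way.

```python
import multiprocessing as mp


class Demo:
    arg = {'batch_size': 1}  # class attribute, mutated in place below

    def __init__(self, arg):
        self.arg.update(arg)


def child(queue):
    # Runs in the forked process and reads the class attribute it inherited.
    queue.put(dict(Demo.arg))


def run_in_fork():
    ctx = mp.get_context('fork')
    queue = ctx.Queue()
    proc = ctx.Process(target=child, args=(queue,))
    proc.start()
    result = queue.get()  # read before join() so the child can flush the queue
    proc.join()
    return result


if __name__ == '__main__':
    Demo({'batch_size': 2})  # mutates Demo.arg in the parent
    print(run_in_fork())     # {'batch_size': 2}
```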

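I guess the reason is that arg is a class variable (rather than an instance variable) and is mutable: self.arg.update(arg) modifies the dict stored on the class and never creates an instance attribute. With 'spawn', each worker starts a fresh interpreter and re-imports the module, so Dataset.arg is back to {'batch_size': 1}, and the dataset reaches the worker by pickling, which carries only the instance's __dict__, not class attributes. self.arg = self.arg puts a reference to the updated dict into the instance __dict__, so it gets pickled. It also makes sense that multiprocessing_context='fork' works, because a forked worker starts with a copy of the parent's memory, including the already-mutated class attribute.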
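Following that explanation, a cleaner fix than self.arg = self.arg is to never mutate the class-level dict and give each instance its own copy. The sketch below keeps the defaults in a class attribute renamed to default_arg (a hypothetical name) and merges them into a new instance attribute in __init__. The torch.utils.data.Dataset base class is omitted only so the snippet runs without PyTorch; the same __init__ should work unchanged in the question's class.

```python
import pickle


class Dataset:  # in the real code: class Dataset(torch.utils.data.Dataset)
    # Read-only defaults; never mutated.
    default_arg = {'batch_size': 1}

    def __init__(self, arg):
        # A new dict stored in the instance __dict__, so pickling (and hence
        # spawned DataLoader workers) receives the merged values.
        self.arg = {**self.default_arg, **arg}


ds = Dataset({'batch_size': 2})
print(pickle.loads(pickle.dumps(ds)).arg)  # {'batch_size': 2}
print(Dataset.default_arg)                 # {'batch_size': 1}
```

Besides surviving 'spawn', this avoids a second pitfall of the original code: because update changed the shared class dict, every later Dataset(...) created in the same process would start from the previous instance's values instead of the defaults.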