I have a problem with this piece of code:
```python
import torch

class Dataset(torch.utils.data.Dataset):
    arg = {'batch_size': 1}  # class attribute, shared by all instances

    def __init__(self, arg):
        print('__init__')
        self.arg.update(arg)
        print(self.arg)

    def _worker_init_fn(self, *args):
        print('worker init')
        print(self.arg)

    def get_dataloader(self):
        return torch.utils.data.DataLoader(self, batch_size=None,
                                           num_workers=3,
                                           worker_init_fn=self._worker_init_fn,
                                           pin_memory=True,
                                           multiprocessing_context='spawn')

    def __getitem__(self, idx):
        return 0

    def __len__(self):
        return 5

def main():
    dataloader = Dataset({'batch_size': 2}).get_dataloader()
    for _ in dataloader:
        pass

if __name__ == '__main__':
    main()
```
Basically I want the workers to have `{'batch_size': 2}`, but they actually have `{'batch_size': 1}`. The code prints the following:
```
__init__
{'batch_size': 2}
worker init
{'batch_size': 1}
worker init
{'batch_size': 1}
worker init
{'batch_size': 1}
```
How can I make the workers see the correct `batch_size`?
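
My guess is that the problem is related to `multiprocessing_context='spawn'`: each worker re-imports the module (which re-creates the class attribute as `{'batch_size': 1}`) and receives a pickled copy of the dataset, and since `self.arg.update(arg)` mutates the class-level dict rather than creating an instance attribute, the updated dict is not in the instance's `__dict__` and so never reaches the workers. A minimal sketch of a workaround I've been considering, under that assumption, would be to store the merged dict as an instance attribute in `__init__`:

```python
def __init__(self, arg):
    print('__init__')
    # Merge into a NEW instance attribute instead of mutating the
    # class-level dict in place; instance attributes live in
    # self.__dict__, so they are pickled along with the dataset and
    # sent to each spawned worker.
    self.arg = {**type(self).arg, **arg}
    print(self.arg)
```

Is that the right way to do it, or is there a more idiomatic approach?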