Iterating a DataLoader on macOS with an Apple M1 chip creates multiple processes
The number of processes is num_workers + 1. Each worker process seems to load the .py file again and execute the top-level statements again. The same script does not cause this problem on a Linux server. I guess this might have something to do with the multiprocessing_context argument of DataLoader, but I have not found a solution.
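My current guess is that the default multiprocessing start method differs between the two platforms, and that a 'spawn'-style start method re-imports the main module in each worker, which would explain the repeated top-level log message. A minimal check of the default (assuming Python 3.8+ defaults; I have not confirmed this is the actual cause):

import multiprocessing

# Default start method is typically 'spawn' on macOS (Python 3.8+) and
# 'fork' on Linux; spawned workers re-import the main script.
print(multiprocessing.get_start_method())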
Here is a short test script:
import logging

from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor

logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setFormatter(fmt=logging.Formatter(fmt='[PID=%(process)d][%(name)s] %(message)s'))
logger.addHandler(console_handler)

# Top-level statement: expected to run once, in the main process only.
logger.warning('This message will be printed by `num_workers` + 1 processes.')

def main():
    data_loader = DataLoader(
        dataset=FashionMNIST(root='./data', train=True, transform=ToTensor()),
        batch_size=128, shuffle=True, num_workers=4,
    )
    for _ in data_loader:
        i = 1  # no-op; we only iterate so the loader starts its workers

if __name__ == '__main__':
    main()
On macOS M1, it prints:
[PID=71509][__main__] This message will be printed by `num_workers` + 1 processes.
[PID=71513][__mp_main__] This message will be printed by `num_workers` + 1 processes.
[PID=71514][__mp_main__] This message will be printed by `num_workers` + 1 processes.
[PID=71515][__mp_main__] This message will be printed by `num_workers` + 1 processes.
[PID=71516][__mp_main__] This message will be printed by `num_workers` + 1 processes.
and on Linux, it prints:
[PID=71509][__main__] This message will be printed by `num_workers` + 1 processes.
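For reference, this is the kind of change I was considering with the multiprocessing_context argument; I have not verified that forcing 'fork' is safe on macOS or that it actually stops the script from being re-imported:

from torch.utils.data import DataLoader
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor

def main():
    data_loader = DataLoader(
        dataset=FashionMNIST(root='./data', train=True, transform=ToTensor()),
        batch_size=128, shuffle=True, num_workers=4,
        # Assumption: 'fork' lets workers inherit the parent process instead of
        # re-importing this script; behaviour on macOS is unverified.
        multiprocessing_context='fork',
    )
    for _ in data_loader:
        pass

if __name__ == '__main__':
    main()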