Hi,
I originally made a thread here but it seems like this category is more appropriate.
The torchvision docs in this link give an example of using torch.jit.script
with transformations.
While this works well in the single process setting, for some reason it does not seem to work in multi-process settings, such as Distributed Data Parallel.
I am only trying to JIT the data transformation module, not the network itself.
The last line of the error message I get is:
File "/home/user/miniconda3/envs/pytorch/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
RuntimeError: Tried to serialize object __torch__.torch.nn.modules.container.Sequential which does not have a __getstate__ method defined!
Is there any workaround? Or any reference that mentions that JIT is not compatible with multi-processing?
Below I’ve attached a minimal code sample that can reproduce the error:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.io as io
import torchvision.transforms as transforms
from torch.utils.data import DistributedSampler, DataLoader, Dataset
from tqdm import tqdm
def torch_image_loader(path):
mode = io.ImageReadMode.RGB
return io.read_image(path, mode)
class RGBDataset(Dataset):
def __init__(self, root="./", num_images=100, split='train', img_transforms=None):
self.root = root
self.num_images = num_images
self.split = split
self.img_transforms = img_transforms
self.file_names = ['red.png', 'green.png', 'blue.png']
self.image_loader = torch_image_loader
def __len__(self):
return self.num_images
def __getitem__(self, idx):
color_index = idx % 3
filename = self.file_names[color_index]
img_path = os.path.join(self.root, filename)
img = self.image_loader(img_path)
if self.img_transforms is not None:
img = self.img_transforms(img)
return img, color_index
def create_dataset():
train_augmentations_list = [transforms.ConvertImageDtype(torch.float),
transforms.Resize([224, ])]
train_augmentations = torch.jit.script(torch.nn.Sequential(*train_augmentations_list))
# train_augmentations = torch.nn.Sequential(*train_augmentations_list)
dataset = RGBDataset(img_transforms=train_augmentations)
return dataset
def create_train_dataloader(dataset):
sampler = DistributedSampler(dataset, shuffle=True)
return DataLoader(dataset, batch_size=8, num_workers=4, shuffle=False, pin_memory=True, drop_last=True,
sampler=sampler)
def main_process(gpu, world_size):
rank = gpu
dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
torch.cuda.set_device(gpu)
train_dataset = create_dataset()
train_dataloader = create_train_dataloader(train_dataset)
for img, label in tqdm(train_dataloader):
pass
if __name__ == "__main__":
num_available_gpus = torch.cuda.device_count()
world_size = num_available_gpus
os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = str(12345)
mp.spawn(main_process, nprocs=world_size, args=(world_size,))
This assumes there are 3 images named red/green/blue.png
in the same path.
The error seems to occur when the torch.jit.script
is called (the error is raised during the create_dataset
call).
The error does not occur when JIT
is not used.
By the way, the environment I am using is:
Conda Python == 3.9.7
PyTorch == 1.10.1
torchvision==0.11.2