JIT transforms don't work in the distributed setting?

The torchvision docs give an example of using torch.jit.script with transforms.
While this works well in the single-process setting, for some reason it does not seem to work in multi-process settings such as DistributedDataParallel (DDP).

I am only trying to JIT the data transformation module, not the network itself.

The last lines of the error message I get are:

```
File "/home/user/miniconda3/envs/pytorch/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
RuntimeError: Tried to serialize object __torch__.torch.nn.modules.container.Sequential which does not have a __getstate__ method defined!
```

Is there any workaround? Or is there a reference stating that JIT is not compatible with multiprocessing?


This is most likely happening because you are initializing some of the dataloader/transform components in the parent process and trying to pass them down to the child processes. Would it be possible to contain all of your dataloading/transformations in the child processes only?

Also, if you have a minimal repro, that would be helpful in debugging this further.
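The suggestion above can be sketched as follows, assuming the transform pipeline is cheap to script once per process. Instead of storing a ScriptModule, the hypothetical `LazyScriptedDataset` keeps the plain (picklable) `nn.Sequential`, scripts it on first use inside whichever process calls `__getitem__`, and drops the scripted copy from its pickled state. `ScaleToUnit` is a hypothetical stand-in for `transforms.ConvertImageDtype(torch.float)`, and a dummy all-white image replaces real file loading.

```python
import pickle
from typing import Optional

import torch
from torch import nn
from torch.utils.data import Dataset


class ScaleToUnit(nn.Module):
    """Hypothetical stand-in for transforms.ConvertImageDtype(torch.float)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.float() / 255.0


class LazyScriptedDataset(Dataset):
    """Hypothetical dataset that scripts its transforms lazily, once per process."""

    def __init__(self, transforms: nn.Module, num_items: int = 6):
        self.transforms = transforms  # plain nn.Module: picklable
        self.num_items = num_items
        self._scripted: Optional[torch.jit.ScriptModule] = None

    def __getstate__(self):
        # Never pickle the ScriptModule; the receiving process rebuilds it.
        state = self.__dict__.copy()
        state["_scripted"] = None
        return state

    def __len__(self) -> int:
        return self.num_items

    def __getitem__(self, idx: int):
        if self._scripted is None:
            # Scripted on first use, inside the process that runs __getitem__.
            self._scripted = torch.jit.script(self.transforms)
        img = torch.full((3, 4, 4), 255, dtype=torch.uint8)  # dummy image
        return self._scripted(img), idx % 3


ds = LazyScriptedDataset(nn.Sequential(ScaleToUnit()))
x, label = ds[0]                         # scripts the pipeline in this process
clone = pickle.loads(pickle.dumps(ds))   # works: the scripted copy was dropped
```

With this layout, DataLoader workers receive a dataset without a ScriptModule whether they are forked or spawned, and each worker scripts its own copy.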

Hi @pritamdamania87, thanks for your reply.

As far as I know, I am not passing anything from the parent to the child processes.
Below I’ve attached a minimal code sample that can reproduce the error:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.io as io
import torchvision.transforms as transforms
from torch.utils.data import DistributedSampler, DataLoader, Dataset
from tqdm import tqdm

def torch_image_loader(path):
    # Decode the image file into a uint8 tensor of shape (3, H, W)
    mode = io.ImageReadMode.RGB
    return io.read_image(path, mode)

class RGBDataset(Dataset):
    def __init__(self, root="./", num_images=100, split='train', img_transforms=None):
        self.root = root
        self.num_images = num_images
        self.split = split
        self.img_transforms = img_transforms
        self.file_names = ['red.png', 'green.png', 'blue.png']
        self.image_loader = torch_image_loader

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Cycle through the three images; the label is the color index
        color_index = idx % 3
        filename = self.file_names[color_index]
        img_path = os.path.join(self.root, filename)
        img = self.image_loader(img_path)
        if self.img_transforms is not None:
            img = self.img_transforms(img)
        return img, color_index

def create_dataset():
    train_augmentations_list = [transforms.ConvertImageDtype(torch.float),
                                transforms.Resize([224, ])]
    # Scripted version: triggers the RuntimeError
    train_augmentations = torch.jit.script(torch.nn.Sequential(*train_augmentations_list))
    # Unscripted version: no error
    # train_augmentations = torch.nn.Sequential(*train_augmentations_list)
    dataset = RGBDataset(img_transforms=train_augmentations)
    return dataset

def create_train_dataloader(dataset):
    # The DistributedSampler shuffles and shards indices per rank,
    # so the DataLoader itself must use shuffle=False
    sampler = DistributedSampler(dataset, shuffle=True)
    return DataLoader(dataset, batch_size=8, num_workers=4, shuffle=False, pin_memory=True, drop_last=True,
                      sampler=sampler)

def main_process(gpu, world_size):
    rank = gpu
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)

    train_dataset = create_dataset()
    train_dataloader = create_train_dataloader(train_dataset)

    for img, label in tqdm(train_dataloader):
        pass

    dist.destroy_process_group()

if __name__ == "__main__":
    num_available_gpus = torch.cuda.device_count()
    world_size = num_available_gpus
    os.environ['MASTER_ADDR'] = "localhost"
    os.environ['MASTER_PORT'] = str(12345)
    # Launch one process per GPU; each calls main_process(rank, world_size)
    mp.spawn(main_process, nprocs=world_size, args=(world_size,))
```

This assumes there are three images named red.png, green.png, and blue.png in the directory the script is run from (the dataset's default root, "./").

When using torch.jit.script in the create_dataset function, the error I get is:

```
RuntimeError: Tried to serialize object __torch__.torch.nn.modules.container.Sequential which does not have a __getstate__ method defined!
```

The error seems to occur when torch.jit.script is called (it is raised during the create_dataset call).
The error does not occur when JIT is not used.

By the way, the environment I am using is:

Conda Python == 3.9.7
PyTorch == 1.10.1
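The JIT/non-JIT difference can be checked without any distributed code. The sketch below, which assumes only `torch`, calls `pickle.dumps` on a plain and a scripted `nn.Sequential`; `nn.ReLU` stands in for the torchvision transforms and `pickle_error` is a hypothetical helper name. The scripted module fails with a message about a missing `__getstate__` method, which suggests the `ForkingPickler.dump` call in the traceback is pickling the dataset (and the scripted transform inside it) to send it to another process.

```python
import pickle
from typing import Optional

import torch
from torch import nn


def pickle_error(obj) -> Optional[str]:
    """Hypothetical helper: return the error message if pickling obj fails, else None."""
    try:
        pickle.dumps(obj)
    except RuntimeError as err:
        return str(err)
    return None


plain = nn.Sequential(nn.ReLU())
scripted = torch.jit.script(nn.Sequential(nn.ReLU()))

print(pickle_error(plain))     # None
print(pickle_error(scripted))  # message about a missing __getstate__ method
```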

Ah, I see. Based on your example, this seems to be primarily an issue with JIT rather than distributed. We have a separate JIT category in the forums for questions like these.

OK, thanks. Will try again in the other category 🙂
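Since the thread concludes this is a JIT question, one JIT-side workaround can be sketched, assuming a serialize/deserialize round trip per receiving process is acceptable: wrap the ScriptModule in an object whose `__getstate__` writes it to bytes with `torch.jit.save` and whose `__setstate__` reads it back with `torch.jit.load`. `PicklableScript` is a hypothetical name, not a PyTorch API, and `nn.ReLU` stands in for the transform pipeline.

```python
import io
import pickle

import torch
from torch import nn


class PicklableScript:
    """Hypothetical wrapper that pickles a ScriptModule via torch.jit.save/load."""

    def __init__(self, module: torch.jit.ScriptModule):
        self.module = module

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.module(x)

    def __getstate__(self):
        buffer = io.BytesIO()
        torch.jit.save(self.module, buffer)  # serialize to a TorchScript archive
        return {"archive": buffer.getvalue()}

    def __setstate__(self, state):
        # Rebuild the ScriptModule from the archive bytes in the receiving process
        self.module = torch.jit.load(io.BytesIO(state["archive"]))


scripted = torch.jit.script(nn.Sequential(nn.ReLU()))
wrapped = PicklableScript(scripted)
restored = pickle.loads(pickle.dumps(wrapped))
x = torch.tensor([-1.0, 2.0])
print(restored(x))  # tensor([0., 2.])
```

In the repro, `RGBDataset(img_transforms=PicklableScript(train_augmentations))` should then pickle cleanly, with each receiving process getting its own deserialized copy of the scripted pipeline.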