JIT transforms in the distributed setting not working

Hi,
I originally made a thread here but it seems like this category is more appropriate.

The torchvision docs in this link give an example of using torch.jit.script with transformations.
While this works well in the single process setting, for some reason it does not seem to work in multi-process settings, such as Distributed Data Parallel.

I am only trying to JIT the data transformation module, not the network itself.

The last line of the error message I get is:

File "/home/user/miniconda3/envs/pytorch/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
RuntimeError: Tried to serialize object __torch__.torch.nn.modules.container.Sequential which does not have a __getstate__ method defined!

Is there any workaround? Or any reference that mentions that JIT is not compatible with multi-processing?

Below I’ve attached a minimal code sample that can reproduce the error:

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.io as io
import torchvision.transforms as transforms
from torch.utils.data import DistributedSampler, DataLoader, Dataset
from tqdm import tqdm


def torch_image_loader(path):
    mode = io.ImageReadMode.RGB
    return io.read_image(path, mode)


class RGBDataset(Dataset):
    def __init__(self, root="./", num_images=100, split='train', img_transforms=None):
        self.root = root
        self.num_images = num_images
        self.split = split
        self.img_transforms = img_transforms
        self.file_names = ['red.png', 'green.png', 'blue.png']
        self.image_loader = torch_image_loader

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        color_index = idx % 3
        filename = self.file_names[color_index]
        img_path = os.path.join(self.root, filename)
        img = self.image_loader(img_path)
        if self.img_transforms is not None:
            img = self.img_transforms(img)
        return img, color_index


def create_dataset():
    train_augmentations_list = [transforms.ConvertImageDtype(torch.float),
                                transforms.Resize([224, ])]
    train_augmentations = torch.jit.script(torch.nn.Sequential(*train_augmentations_list))
    # train_augmentations = torch.nn.Sequential(*train_augmentations_list)
    dataset = RGBDataset(img_transforms=train_augmentations)
    return dataset


def create_train_dataloader(dataset):
    sampler = DistributedSampler(dataset, shuffle=True)
    return DataLoader(dataset, batch_size=8, num_workers=4, shuffle=False, pin_memory=True, drop_last=True,
                      sampler=sampler)


def main_process(gpu, world_size):
    rank = gpu
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=rank)
    torch.cuda.set_device(gpu)

    train_dataset = create_dataset()
    train_dataloader = create_train_dataloader(train_dataset)

    for img, label in tqdm(train_dataloader):
        pass


if __name__ == "__main__":
    num_available_gpus = torch.cuda.device_count()
    world_size = num_available_gpus
    os.environ['MASTER_ADDR'] = "localhost"
    os.environ['MASTER_PORT'] = str(12345)
    mp.spawn(main_process, nprocs=world_size, args=(world_size,))

This assumes there are 3 images named red/green/blue.png in the same path.

The error seems to occur when the torch.jit.script is called (the error is raised during the create_dataset call).
The error does not occur when JIT is not used.

By the way, the environment I am using is:

Conda Python == 3.9.7
PyTorch == 1.10.1
torchvision==0.11.2