Dataloader issues with multiprocessing when i do torch.multiprocessing.set_start_method("spawn", force = True)

Does anyone know what this error is?

 Traceback (most recent call last):
 File "./train.py", line 506, in <module>
e = t.train()
File "./train.py", line 188, in train
self.train_generator = iter(self.train_loader)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 278, in __iter__
return _MultiProcessingDataLoaderIter(self)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 682, in __init__
w.start()
File "/usr/lib/python3.5/multiprocessing/process.py", line 105, in start
self._popen = self._Popen(self)
File "/usr/lib/python3.5/multiprocessing/context.py", line 212, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "/usr/lib/python3.5/multiprocessing/context.py", line 274, in _Popen
return Popen(process_obj)
File "/usr/lib/python3.5/multiprocessing/popen_spawn_posix.py", line 33, in __init__
super().__init__(process_obj)
File "/usr/lib/python3.5/multiprocessing/popen_fork.py", line 20, in __init__
self._launch(process_obj)
File "/usr/lib/python3.5/multiprocessing/popen_spawn_posix.py", line 48, in _launch
reduction.dump(process_obj, fp)
File "/usr/lib/python3.5/multiprocessing/reduction.py", line 59, in dump
ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <class 'Environment'>: attribute lookup Environment on builtins failed

My dataloader code is as follows:

from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
import pandas as pd
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
import os
import lmdb
import pyarrow
import lz4framed
from typing import Any
from PIL import Image
from lmdbconverter import convert


class InvalidFileException(Exception):
pass

class BirdDatesetLMDB(Dataset):
def __init__(self, train = True, store_path = 'CUB_200_2011', transform=None):
    super().__init__()
    
    tr_te_fn = os.path.join(store_path, 'train_test_split.txt')
    labels = dict([map(int, line.rstrip('\n').split(' ')) for line in open(os.path.join(store_path, 'image_class_labels.txt'))])
    df = pd.read_csv(os.path.join(store_path, 'attributes/images_by_attributes.csv'), index_col=0)
    struct_features = dict(zip(range(1,df.shape[0]+1), df.values))

    bdng_box = dict(
        map(lambda x: (x[0], x[1:]), 
            [list(map(int, map(float, 
                line.rstrip('\n').split(' ')))) 
            for line in open(os.path.join(store_path, 'bounding_boxes.txt'))]))
    
    self.labels = []
    self.struct = []
    self.bdng_box = []
    
    tr_te = np.loadtxt(tr_te_fn)
    
    if train:
        lmdb_store_path = store_path + '/train.db'
        ids = np.where(tr_te[:,1] == 1)[0] + 1
    else:
        lmdb_store_path = store_path + '/test.db'
        ids = np.where(tr_te[:,1] == 0)[0] + 1
    
    self.N=len(ids)
    
    for i in range(self.N):
        img_id = ids[i]
        self.labels.append(labels[img_id]-1)
        self.struct.append(struct_features[img_id])

        if bdng_box is not None:
            self.bdng_box.append(bdng_box[img_id])

    
    if train:
        if not os.path.isfile(lmdb_store_path):
            convert(store_path + '/train', store_path + '/train.db', write_freq=1000)
    else:
        if not os.path.isfile(lmdb_store_path):
            convert(store_path + '/test', store_path + '/test.db', write_freq=1000)

    self.lmdb_store_path = lmdb_store_path
    self.lmdb_connection = lmdb.open(lmdb_store_path,
                                     subdir=False, readonly=True, lock=False, readahead=False, meminit=False)

    with self.lmdb_connection.begin(write=False) as lmdb_txn:
        self.length = lmdb_txn.stat()['entries'] - 1
        self.keys = pyarrow.deserialize(lz4framed.decompress(lmdb_txn.get(b'__keys__')))
        print("Total records: [{}, {}]".format(len(self.keys), self.length))
    self.transform = transform

def __getitem__(self, index):
    lmdb_value = None
    with self.lmdb_connection.begin(write=False) as txn:
        lmdb_value = txn.get(self.keys[index])
    #assert lmdb_value is not None, f"Read empty record for key: {self.keys[index]}"

    img_name, img_arr, img_shape = BirdDatesetLMDB.decompress_and_deserialize(lmdb_value=lmdb_value)
    image = np.frombuffer(img_arr, dtype=np.uint8).reshape(img_shape)
    if image.size == 0:
        raise InvalidFileException("Invalid file found, skipping")
    if self.bdng_box is not None:
        x, y, width, height = self.bdng_box[index]
        image = transforms.functional.resized_crop(Image.fromarray(image), y, x, height, width, (224,224))
    if self.transform:
        image = self.transform(image)
        

    label = torch.as_tensor(self.labels[index])
    struct = torch.as_tensor(self.struct[index]).float()
    
    return image, label

@staticmethod
def decompress_and_deserialize(lmdb_value: Any):
    return pyarrow.deserialize(lz4framed.decompress(lmdb_value))

def __len__(self):
    return self.length

def read_data(batch_size, valid_size=0.01, num_workers = 2): 
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transformer = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        # transforms.ColorJitter(.4, .4, .4),
        transforms.ToTensor(),
        normalize,
    ])

test_tfms = transforms.Compose([
        transforms.Resize(int(224 * 1.14)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])


trainset = BirdDatesetLMDB(train=True, transform=transformer)
validset = BirdDatesetLMDB(train=True, transform=test_tfms)
    
testset = BirdDatesetLMDB(train=False, transform=test_tfms)

num_train = len(trainset)
indices = torch.randperm(num_train).tolist()
split = int(np.floor(valid_size * num_train))

    
train_idx, valid_idx = indices[split:], indices[:split]
print(len(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
    
train_loader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=True)
valid_loader = DataLoader(validset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers, pin_memory=True)
    
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

return train_loader, valid_loader, test_loader

def birds_collate_fn(data):
# the collate function for dataloader
collate_images = []
collate_labels = []
collate_struct = []

for d in data:
    collate_images.append(d['c'])
    collate_labels.append(d['b'])
    collate_struct.append((d['a']))
collate_images = torch.stack(collate_images, dim=0)
collate_labels = torch.stack(collate_labels, dim=0)
collate_struct = torch.stack(collate_struct, dim=0)

return {
    #'img': 
        collate_images,
    #'label': 
        collate_labels,
    #'struct': collate_struct
}

Based on this comment it seems lmdb creates this issue.

Thank you. is there a way to solve it and still using lmdb? or i need to avoid lmdb?

I tried my code with following dataset which does not use lmdb, and I still have the same issue:

class birds_dataset(Dataset):
def __init__(self, data_dir, image_ids, image_dirs, labels, struct_features, bdng_box=None, transforms=None):
    self.imgs = []
    self.labels = []
    self.struct = []
    self.bdng_box = []

    self.N=len(image_ids)
    
    for i in tqdm.tqdm(range(self.N),'Loading birds data to memory'):
        img_id = image_ids[i]
        img_fn = os.path.join(data_dir, 'images', image_dirs[img_id])
        self.imgs.append(img_fn)
        self.labels.append(labels[img_id]-1)
        self.struct.append(struct_features[img_id])

        if bdng_box is not None:
            self.bdng_box.append(bdng_box[img_id])

    self.transform = transforms
 
def __len__(self):
    return self.N

def __getitem__(self, idx):
    img = Image.open(self.imgs[idx])
    img = img.convert('RGB')
    if self.bdng_box is not None:
        x, y, width, height = self.bdng_box[idx]
        img = transforms.functional.resized_crop(img, y, x, height, width, (224,224))
    if self.transform is not None:
        img = self.transform(img)
    else:
        img = transforms.ToTensor()(img).convert("RGB")
    
    label = torch.as_tensor(self.labels[idx])
    struct = torch.as_tensor(self.struct[idx]).float()

    return {'c': img, 'b': label, 'a': struct}

I figure out using torch.multiprocessing.set_start_method('spawn') causes the problem. My code runs with no problem on cpu, when i do not set this. However, i believe this is necessary to be set for when i use cuda, right?

Does your code run fine without setting the start method explicitly?
Also, could you wrap your code in the if-clause:

import torch

def main()
    for i, data in enumerate(dataloader):
        # do something here

if __name__ == '__main__':
    main()

It runs without problem just on cpu, but not on gpu.

I am following everything you said and I still get the error. I changed the dataset to a one that is not based on lmdb and run everything under the if condition. Anything else comes to your mind, please?

Could you post a code snippet to reproduce this error (you can use random data to feed to the DataLoader)?