DataLoader issues with multiprocessing when I call torch.multiprocessing.set_start_method("spawn", force=True)

Does anyone know what this error is?

 Traceback (most recent call last):
 File "./", line 506, in <module>
e = t.train()
File "./", line 188, in train
self.train_generator = iter(self.train_loader)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/", line 278, in __iter__
return _MultiProcessingDataLoaderIter(self)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/", line 682, in __init__
File "/usr/lib/python3.5/multiprocessing/", line 105, in start
self._popen = self._Popen(self)
File "/usr/lib/python3.5/multiprocessing/", line 212, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "/usr/lib/python3.5/multiprocessing/", line 274, in _Popen
return Popen(process_obj)
File "/usr/lib/python3.5/multiprocessing/", line 33, in __init__
File "/usr/lib/python3.5/multiprocessing/", line 20, in __init__
File "/usr/lib/python3.5/multiprocessing/", line 48, in _launch
reduction.dump(process_obj, fp)
File "/usr/lib/python3.5/multiprocessing/", line 59, in dump
ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <class 'Environment'>: attribute lookup Environment on builtins failed

My dataloader code is as follows:

# Dataset, DataLoader and SubsetRandomSampler all live in torch.utils.data.
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
import pandas as pd
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
import os
import lmdb
import pyarrow
import lz4framed
from typing import Any
from PIL import Image
from lmdbconverter import convert

class InvalidFileException(Exception):
    """Raised when a stored record cannot be turned into a usable image."""

class BirdDatesetLMDB(Dataset):
    """CUB-200-2011 split whose images are read from an LMDB store.

    An open ``lmdb.Environment`` cannot be pickled. With the ``spawn`` start
    method the ``DataLoader`` pickles the dataset for every worker, so keeping
    the handle as a plain attribute fails with ``PicklingError: Can't pickle
    <class 'Environment'>``. The handle is therefore opened lazily (once per
    process) through the ``lmdb_connection`` property and is dropped from the
    pickled state.

    Args:
        train: Use the training split (``train.db``) when true, otherwise the
            test split (``test.db``).
        store_path: Directory holding the CUB metadata files and the LMDB
            stores.
        transform: Optional callable applied to the cropped PIL image.
    """

    def __init__(self, train=True, store_path='CUB_200_2011', transform=None):
        tr_te_fn = os.path.join(store_path, 'train_test_split.txt')

        # Metadata files are read inside ``with`` so the handles are closed.
        with open(os.path.join(store_path, 'image_class_labels.txt')) as fh:
            labels = dict(tuple(map(int, line.rstrip('\n').split(' '))) for line in fh)
        df = pd.read_csv(os.path.join(store_path, 'attributes/images_by_attributes.csv'), index_col=0)
        struct_features = dict(zip(range(1, df.shape[0] + 1), df.values))

        # image id -> [x, y, width, height]; the file stores floats, kept as ints.
        bdng_box = {}
        with open(os.path.join(store_path, 'bounding_boxes.txt')) as fh:
            for line in fh:
                key, *box = (int(float(v)) for v in line.rstrip('\n').split(' '))
                bdng_box[key] = box

        tr_te = np.loadtxt(tr_te_fn)
        split_name = 'train' if train else 'test'
        lmdb_store_path = store_path + '/' + split_name + '.db'
        # Row i of the split file is taken to describe image id i + 1.
        ids = np.where(tr_te[:, 1] == (1 if train else 0))[0] + 1

        # NOTE(review): the per-image loop body is missing from the posted
        # code; it is rebuilt from the lists that __getitem__ reads.
        self.N = len(ids)
        self.labels = [labels[img_id] for img_id in ids]
        self.struct = [struct_features[img_id] for img_id in ids]
        self.bdng_box = [bdng_box[img_id] for img_id in ids]

        # Build the LMDB store for this split on first use.
        if not os.path.isfile(lmdb_store_path):
            convert(store_path + '/' + split_name, lmdb_store_path, write_freq=1000)

        self.lmdb_store_path = lmdb_store_path
        self.transform = transform
        self._env = None

        with self.lmdb_connection.begin(write=False) as lmdb_txn:
            self.length = lmdb_txn.stat()['entries'] - 1
            self.keys = pyarrow.deserialize(lz4framed.decompress(lmdb_txn.get(b'__keys__')))
        print("Total records: [{}, {}]".format(len(self.keys), self.length))

        # Close again so forked workers do not inherit an open environment;
        # every process reopens its own handle on first access.
        self._close()

    @property
    def lmdb_connection(self):
        """Return this process's LMDB environment, opening it on first use."""
        if self._env is None:
            self._env = lmdb.open(self.lmdb_store_path,
                                  subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
        return self._env

    def _close(self):
        """Close the LMDB environment held by this process, if any."""
        if self._env is not None:
            self._env.close()
            self._env = None

    def __getstate__(self):
        """Return picklable state: everything except the open LMDB handle."""
        state = self.__dict__.copy()
        state['_env'] = None
        return state

    def __getitem__(self, index):
        """Return ``(image, label)`` for the record at ``index``.

        Raises:
            InvalidFileException: If the record is missing or decodes to an
                empty image.
        """
        with self.lmdb_connection.begin(write=False) as txn:
            lmdb_value = txn.get(self.keys[index])
        if lmdb_value is None:
            raise InvalidFileException("Read empty record for key: {}".format(self.keys[index]))

        img_name, img_arr, img_shape = BirdDatesetLMDB.decompress_and_deserialize(lmdb_value=lmdb_value)
        image = np.frombuffer(img_arr, dtype=np.uint8).reshape(img_shape)
        if image.size == 0:
            raise InvalidFileException("Invalid file found, skipping")

        # NOTE(review): assumes self.keys is ordered like the split's image
        # ids, so that index selects the matching box and label - confirm.
        x, y, width, height = self.bdng_box[index]
        image = transforms.functional.resized_crop(Image.fromarray(image), y, x, height, width, (224, 224))
        if self.transform:
            image = self.transform(image)

        label = torch.as_tensor(self.labels[index])
        # The structured features in self.struct are not part of the returned
        # sample, so they are no longer converted here.
        return image, label

    @staticmethod
    def decompress_and_deserialize(lmdb_value: Any):
        """Decompress an lz4-framed record and deserialize it with pyarrow."""
        return pyarrow.deserialize(lz4framed.decompress(lmdb_value))

    def __len__(self):
        """Return the store's entry count minus one (presumably ``__keys__``)."""
        return self.length

# Build the train / validation / test DataLoaders over BirdDatesetLMDB.
#   batch_size:  batch size passed to all three loaders.
#   valid_size:  fraction of the training split held out for validation.
#   num_workers: worker count passed to all three loaders.
# Returns (train_loader, valid_loader, test_loader).
# NOTE(review): the function body has lost its indentation in this excerpt.
def read_data(batch_size, valid_size=0.01, num_workers = 2): 
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# NOTE(review): this Compose list is truncated here - its remaining entries
# and closing brackets are missing, and `normalize` is not used in any
# visible line.
transformer = transforms.Compose([
        # transforms.RandomResizedCrop(224),
        # transforms.ColorJitter(.4, .4, .4),

# NOTE(review): also truncated after the Resize entry.
test_tfms = transforms.Compose([
        transforms.Resize(int(224 * 1.14)),

# The validation set reads the training split but uses the test transforms.
trainset = BirdDatesetLMDB(train=True, transform=transformer)
validset = BirdDatesetLMDB(train=True, transform=test_tfms)
testset = BirdDatesetLMDB(train=False, transform=test_tfms)

# Shuffle the training indices once; the first `split` of them become the
# validation subset and the rest the training subset.
num_train = len(trainset)
indices = torch.randperm(num_train).tolist()
split = int(np.floor(valid_size * num_train))

train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=True)
valid_loader = DataLoader(validset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

return train_loader, valid_loader, test_loader

# Collate a list of samples into batched tensors by stacking images, labels
# and structured features along a new leading dimension.
# NOTE(review): the body has lost its indentation in this excerpt, and none
# of the DataLoader calls visible in this post passes this function.
def birds_collate_fn(data):
# the collate function for dataloader
collate_images = []
collate_labels = []
collate_struct = []

# NOTE(review): the loop body that fills the three lists from each sample `d`
# is missing from this excerpt.
for d in data:
collate_images = torch.stack(collate_images, dim=0)
collate_labels = torch.stack(collate_labels, dim=0)
collate_struct = torch.stack(collate_struct, dim=0)

# NOTE(review): the returned dict is truncated - only the commented-out
# 'struct' entry survives and the closing brace is missing.
return {
    #'struct': collate_struct

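Based on this comment, it seems lmdb causes this issue: the open lmdb `Environment` handle stored on the dataset cannot be pickled when the worker processes are started with spawn.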

Thank you. Is there a way to solve it while still using lmdb, or do I need to avoid lmdb?

I tried my code with the following dataset, which does not use lmdb, and I still have the same issue:

class birds_dataset(Dataset):
    """CUB birds dataset that opens each image file on access.

    Args:
        data_dir: Root directory; images are read from ``<data_dir>/images``.
        image_ids: Sequence of image ids that make up this split.
        image_dirs: Mapping from image id to the image's relative path.
        labels: Mapping from image id to class label.
        struct_features: Mapping from image id to its structured features.
        bdng_box: Optional mapping from image id to ``(x, y, width, height)``.
            When omitted, images are not cropped.
        transforms: Optional callable applied to the PIL image. When omitted,
            the image is converted with ``ToTensor``.
    """

    def __init__(self, data_dir, image_ids, image_dirs, labels, struct_features, bdng_box=None, transforms=None):
        self.N = len(image_ids)
        self.imgs = []
        self.labels = []
        self.struct = []
        # Stay None when no boxes are given so __getitem__ can skip the crop;
        # an always-present empty list would raise IndexError there.
        self.bdng_box = [] if bdng_box is not None else None

        # NOTE(review): the loop body is missing from the posted code; it is
        # rebuilt from the attributes that __getitem__ reads.
        for i in tqdm.tqdm(range(self.N), 'Loading birds data to memory'):
            img_id = image_ids[i]
            img_fn = os.path.join(data_dir, 'images', image_dirs[img_id])
            self.imgs.append(img_fn)
            self.labels.append(labels[img_id])
            self.struct.append(struct_features[img_id])

            if bdng_box is not None:
                self.bdng_box.append(bdng_box[img_id])

        self.transform = transforms

    def __len__(self):
        """Return the number of images in this split."""
        return self.N

    def __getitem__(self, idx):
        """Return ``{'c': image, 'b': label, 'a': struct_features}`` for ``idx``."""
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(self.imgs[idx]) as fh:
            img = fh.convert('RGB')
        if self.bdng_box is not None:
            x, y, width, height = self.bdng_box[idx]
            img = transforms.functional.resized_crop(img, y, x, height, width, (224, 224))
        if self.transform is not None:
            img = self.transform(img)
        else:
            # NOTE(review): the posted code called .convert("RGB") on the
            # ToTensor() result, which a tensor does not have; the image is
            # already RGB at this point.
            img = transforms.ToTensor()(img)
        label = torch.as_tensor(self.labels[idx])
        struct = torch.as_tensor(self.struct[idx]).float()

        return {'c': img, 'b': label, 'a': struct}

I figured out that calling torch.multiprocessing.set_start_method('spawn') causes the problem. My code runs without problems on the CPU when I do not set it. However, I believe it has to be set when I use CUDA, right?

Does your code run fine without setting the start method explicitly?
Also, could you wrap your code in an `if __name__ == '__main__':` guard, like this:

import torch

def main():
    """Run the data-loading loop; `dataloader` stands in for your DataLoader."""
    for i, data in enumerate(dataloader):
        # do something here
        pass


# With the "spawn" start method every worker process re-imports this module,
# so the loop must only run when the file is executed as the main script.
if __name__ == '__main__':
    main()

It runs without problems on the CPU, but not on the GPU.

I followed everything you said and I still get the error. I switched to a dataset that is not based on lmdb and ran everything under the if condition. Does anything else come to mind, please?

Could you post a code snippet to reproduce this error (you can use random data to feed to the DataLoader)?

Could you fix this? I am facing the same issue with a custom dataset. If I want to use the GPU I have to use the spawn method which breaks the DataLoader.

Hi, I'm also running into this problem. Is there a solution yet?

Unfortunately, I don’t have a code snippet to reproduce the issue, so feel free to provide a minimal executable code snippet.

import math
from typing import Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    Args:
        d_model: Size of the last (embedding) dimension of the input.
        dropout: Dropout probability applied after the encodings are added.
        max_len: Longest sequence length the precomputed table supports.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        # nn.Module must be initialised before submodules or buffers are set.
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 0, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # A buffer follows the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Positional encoding, a Transformer encoder stack and a linear decoder.

    Args:
        ntoken: Output size of the final linear layer.
        d_model: Feature size used throughout the encoder.
        nhead: Number of attention heads per encoder layer.
        d_hid: Hidden size of each encoder layer's feed-forward network.
        nlayers: Number of stacked encoder layers.
        dropout: Dropout probability for the positional encoding and layers.
    """

    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int,
                 nlayers: int, dropout: float = 0.5):
        # nn.Module must be initialised before submodules are assigned.
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, ntoken)

        # NOTE(review): this call is not visible in the posted code; it is
        # presumably what stood in the blank lines after the decoder.
        self.init_weights()

    def init_weights(self) -> None:
        """Zero the decoder bias and draw its weights uniformly from ±0.1."""
        # NOTE(review): rebuilt from the garbled `initrange = 0.1, initrange)`
        # line; confirm against the original code.
        initrange = 0.1
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src: Tensor) -> Tensor:
        """
        Args:
            src: Tensor. NOTE(review): the posted docstring gives the shape
                as [seq_len, batch_size] and also lists a ``src_mask``, but
                no embedding layer or mask parameter is visible, so ``src``
                presumably has to be [seq_len, batch_size, d_model] already.

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src)
        output = self.decoder(output)
        return output

Hi, I'm receiving the same prediction for whatever input I pass in. Above is my network. Could you kindly help?

Your code snippet isn’t using a DataLoader with multiple workers, so I’m unsure how it’s related to this thread. Could you update your code to show where the DataLoader is failing?

Hi yes

from torch.utils.data import Dataset, DataLoader
import numpy as np
class BCDataset(Dataset):
    """Behavior Cloning Dataset.

    Each sample concatenates a text feature vector (from ``model_perceiver``)
    with an image feature vector (from ``model_vit``).

    Args:
        df: Optional DataFrame of instructions and actions. When omitted,
            ``instructions_and_actions.csv`` is read instead.
        transform: Optional callable applied to the loaded image.
    """

    def __init__(self, df=None, transform=None):
        # `df if df else ...` raises ValueError for a DataFrame because its
        # truth value is ambiguous, so compare against None explicitly.
        self.ins_and_act_df = df if df is not None else pd.read_csv('instructions_and_actions.csv')
        self.n_samples = self.ins_and_act_df.shape[0]
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(normalized_features, labels)`` for row ``index``."""
        # Obtain text
        encoding_perceiver = tokenizer(self.ins_and_act_df.iloc[index, 1], padding=True, return_tensors="pt")
        inputs, input_mask = encoding_perceiver.input_ids, encoding_perceiver.attention_mask
        outputs_perceiver = model_perceiver(inputs=inputs, attention_mask=input_mask)
        hidden_states_perceiver = outputs_perceiver.hidden_states[-1][:, 0, :].detach()

        # Obtain image
        img = Image.open('images/{}.png'.format(index))
        encoding_vit = feature_extractor(images=img, return_tensors="pt")
        pixel_values = encoding_vit['pixel_values']
        outputs_vit = model_vit(pixel_values)
        last_hidden_states = outputs_vit.last_hidden_state[:, 0, :].detach()

        # Obtain labels
        y = torch.from_numpy(np.array(self.ins_and_act_df.iloc[index, 2:], dtype=np.int8))

        # NOTE(review): the transform runs after the image features were
        # extracted, so it does not affect the returned tensors.
        if self.transform:
            img = self.transform(img)

        # NOTE(review): if `device` is a CUDA device, this runs inside the
        # DataLoader worker processes - confirm that is intended.
        x = torch.cat((hidden_states_perceiver, last_hidden_states), axis=1).to(device)
        # NOTE(review): the posted return line is cut off after the comma;
        # the labels are presumably the second element.
        return torch.nn.functional.normalize(x), y

    def __len__(self):
        """Return the number of rows in the instructions/actions table."""
        return self.n_samples

dataset = BCDataset()
# 80/20 split. The validation size is the remainder, because random_split
# requires the lengths to sum to len(dataset) and int(0.8 * n) + int(0.2 * n)
# can fall short of n.
n_train = int(dataset.n_samples * 0.8)
n_val = dataset.n_samples - n_train
train_set, val_set = torch.utils.data.random_split(dataset, [n_train, n_val])
train_loader = DataLoader(dataset=train_set, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(dataset=val_set, batch_size=4, shuffle=True, num_workers=2)

I don’t see any obvious issues in your code, but it’s also not executable so I won’t be able to try to reproduce and debug the issue.