Does anyone know what this error is?
Traceback (most recent call last):
File "./train.py", line 506, in <module>
e = t.train()
File "./train.py", line 188, in train
self.train_generator = iter(self.train_loader)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 278, in __iter__
return _MultiProcessingDataLoaderIter(self)
File "/usr/local/lib/python3.5/dist-packages/torch/utils/data/dataloader.py", line 682, in __init__
w.start()
File "/usr/lib/python3.5/multiprocessing/process.py", line 105, in start
self._popen = self._Popen(self)
File "/usr/lib/python3.5/multiprocessing/context.py", line 212, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "/usr/lib/python3.5/multiprocessing/context.py", line 274, in _Popen
return Popen(process_obj)
File "/usr/lib/python3.5/multiprocessing/popen_spawn_posix.py", line 33, in __init__
super().__init__(process_obj)
File "/usr/lib/python3.5/multiprocessing/popen_fork.py", line 20, in __init__
self._launch(process_obj)
File "/usr/lib/python3.5/multiprocessing/popen_spawn_posix.py", line 48, in _launch
reduction.dump(process_obj, fp)
File "/usr/lib/python3.5/multiprocessing/reduction.py", line 59, in dump
ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <class 'Environment'>: attribute lookup Environment on builtins failed
My dataloader code is as follows:
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision import transforms
import pandas as pd
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler
import os
import lmdb
import pyarrow
import lz4framed
from typing import Any
from PIL import Image
from lmdbconverter import convert
class InvalidFileException(Exception):
pass
class BirdDatesetLMDB(Dataset):
def __init__(self, train = True, store_path = 'CUB_200_2011', transform=None):
super().__init__()
tr_te_fn = os.path.join(store_path, 'train_test_split.txt')
labels = dict([map(int, line.rstrip('\n').split(' ')) for line in open(os.path.join(store_path, 'image_class_labels.txt'))])
df = pd.read_csv(os.path.join(store_path, 'attributes/images_by_attributes.csv'), index_col=0)
struct_features = dict(zip(range(1,df.shape[0]+1), df.values))
bdng_box = dict(
map(lambda x: (x[0], x[1:]),
[list(map(int, map(float,
line.rstrip('\n').split(' '))))
for line in open(os.path.join(store_path, 'bounding_boxes.txt'))]))
self.labels = []
self.struct = []
self.bdng_box = []
tr_te = np.loadtxt(tr_te_fn)
if train:
lmdb_store_path = store_path + '/train.db'
ids = np.where(tr_te[:,1] == 1)[0] + 1
else:
lmdb_store_path = store_path + '/test.db'
ids = np.where(tr_te[:,1] == 0)[0] + 1
self.N=len(ids)
for i in range(self.N):
img_id = ids[i]
self.labels.append(labels[img_id]-1)
self.struct.append(struct_features[img_id])
if bdng_box is not None:
self.bdng_box.append(bdng_box[img_id])
if train:
if not os.path.isfile(lmdb_store_path):
convert(store_path + '/train', store_path + '/train.db', write_freq=1000)
else:
if not os.path.isfile(lmdb_store_path):
convert(store_path + '/test', store_path + '/test.db', write_freq=1000)
self.lmdb_store_path = lmdb_store_path
self.lmdb_connection = lmdb.open(lmdb_store_path,
subdir=False, readonly=True, lock=False, readahead=False, meminit=False)
with self.lmdb_connection.begin(write=False) as lmdb_txn:
self.length = lmdb_txn.stat()['entries'] - 1
self.keys = pyarrow.deserialize(lz4framed.decompress(lmdb_txn.get(b'__keys__')))
print("Total records: [{}, {}]".format(len(self.keys), self.length))
self.transform = transform
def __getitem__(self, index):
lmdb_value = None
with self.lmdb_connection.begin(write=False) as txn:
lmdb_value = txn.get(self.keys[index])
#assert lmdb_value is not None, f"Read empty record for key: {self.keys[index]}"
img_name, img_arr, img_shape = BirdDatesetLMDB.decompress_and_deserialize(lmdb_value=lmdb_value)
image = np.frombuffer(img_arr, dtype=np.uint8).reshape(img_shape)
if image.size == 0:
raise InvalidFileException("Invalid file found, skipping")
if self.bdng_box is not None:
x, y, width, height = self.bdng_box[index]
image = transforms.functional.resized_crop(Image.fromarray(image), y, x, height, width, (224,224))
if self.transform:
image = self.transform(image)
label = torch.as_tensor(self.labels[index])
struct = torch.as_tensor(self.struct[index]).float()
return image, label
@staticmethod
def decompress_and_deserialize(lmdb_value: Any):
return pyarrow.deserialize(lz4framed.decompress(lmdb_value))
def __len__(self):
return self.length
def read_data(batch_size, valid_size=0.01, num_workers = 2):
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transformer = transforms.Compose([
# transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
# transforms.ColorJitter(.4, .4, .4),
transforms.ToTensor(),
normalize,
])
test_tfms = transforms.Compose([
transforms.Resize(int(224 * 1.14)),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize,
])
trainset = BirdDatesetLMDB(train=True, transform=transformer)
validset = BirdDatesetLMDB(train=True, transform=test_tfms)
testset = BirdDatesetLMDB(train=False, transform=test_tfms)
num_train = len(trainset)
indices = torch.randperm(num_train).tolist()
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
print(len(valid_idx))
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
train_loader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers, pin_memory=True)
valid_loader = DataLoader(validset, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
return train_loader, valid_loader, test_loader
def birds_collate_fn(data):
# the collate function for dataloader
collate_images = []
collate_labels = []
collate_struct = []
for d in data:
collate_images.append(d['c'])
collate_labels.append(d['b'])
collate_struct.append((d['a']))
collate_images = torch.stack(collate_images, dim=0)
collate_labels = torch.stack(collate_labels, dim=0)
collate_struct = torch.stack(collate_struct, dim=0)
return {
#'img':
collate_images,
#'label':
collate_labels,
#'struct': collate_struct
}