I wrote the following code and it behaves really strangely: sometimes next(iter(dataloader)) works fine and sometimes it throws an error. I shared the error below. Any thoughts?
import os

import h5py
import numpy as np
import torch
from PIL import Image
from scipy.io import loadmat
from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler
from torchvision import transforms


def proc_images(data_dir='flower-data', train=True):
    """
    Saves compressed, resized images as HDF5 datasets.
    Returns
        data.h5, where each dataset is an image or class label,
        e.g. X23, y23 = image and corresponding class label
    """
    image_path_list = sorted([os.path.join(data_dir + '/jpg', filename)
                              for filename in os.listdir(data_dir + '/jpg')
                              if not filename.startswith('.')])
    label_dict = loadmat(data_dir + '/config/imagelabels.mat')
    label_list = label_dict['labels'][0]
    label_list[:] = [i - 1 for i in label_list]
    ids = loadmat(data_dir + '/config/setid.mat')
    if train:
        ids = np.concatenate((ids['trnid'], ids['valid']), 1) - 1
        name_data = 'flowers_train_data.h5'
        name_labels = 'flowers_train_labels.h5'
    else:
        ids = ids['tstid'] - 1
        name_data = 'flowers_test_data.h5'
        name_labels = 'flowers_test_labels.h5'
    label_list = [label_list[index] for index in ids.astype(int).tolist()[0]]
    image_path_list = [image_path_list[index] for index in ids.astype(int).tolist()[0]]
    with h5py.File(name_data, 'w') as hf:
        for i, img in enumerate(image_path_list):
            # Images
            image = Image.open(img)
            (HEIGHT, WIDTH) = image.size
            Xset = hf.create_dataset(
                name='X' + str(i),
                data=image)
            print("\r", i, end="")
    with h5py.File(name_labels, 'w') as hf:
        for i, img in enumerate(image_path_list):
            yset = hf.create_dataset(
                name='y' + str(i),
                data=label_list[i],
                shape=(1,),
                maxshape=(None,))
            print("\r", i, end="")
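For context, each image ends up as its own dataset keyed 'X0', 'X1', ..., with labels 'y0', 'y1', ... in the second file. This is the kind of quick inspection I use to look at the layout (just a sketch; it assumes flowers_train_data.h5 has already been written to the working directory):

import h5py

# Sanity check on the HDF5 layout produced by proc_images.
with h5py.File('flowers_train_data.h5', 'r') as hf:
    print(len(hf), "image datasets")           # one dataset per image: X0, X1, ...
    print(hf['X0'].shape, hf['X0'].dtype)      # raw image array as saved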
class FlowersDataset(Dataset):
    def __init__(self, train=True, transform=None):
        if not train:
            if os.path.isfile('flowers_test_data.h5'):
                print("test data exist")
            else:
                proc_images(data_dir='flower-data', train=False)
        if train:
            if os.path.isfile('flowers_train_data.h5'):
                print("train data exist")
            else:
                proc_images(data_dir='flower-data', train=True)
        self.transform = transform
        if train:
            self.hf_data = h5py.File('flowers_train_data.h5', 'r')
            self.hf_labels = h5py.File('flowers_train_labels.h5', 'r')
        else:
            self.hf_data = h5py.File('flowers_test_data.h5', 'r')
            self.hf_labels = h5py.File('flowers_test_labels.h5', 'r')

    def __len__(self):
        return len(self.hf_data)

    def __getitem__(self, idx):
        image = Image.fromarray(self.hf_data.get('X' + str(idx)).value)
        label = self.hf_labels.get('y' + str(idx)).value
        if self.transform:
            image = self.transform(image)
        return image, label
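For what it's worth, this is roughly how I poke at the dataset by itself, without a DataLoader (a sketch with no transform; the index 0 is arbitrary):

# Indexing the dataset directly in the main process, no worker processes involved.
ds = FlowersDataset(train=True)
img, lbl = ds[0]        # PIL image (no transform applied) and its label array
print(len(ds), img.size, lbl)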
def read_data(batch_size, valid_size=0.01, num_workers=2):
    transform_train = transforms.Compose([transforms.RandomRotation(30),
                                          transforms.Resize(255),
                                          transforms.CenterCrop(224),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize([0.485, 0.456, 0.406],
                                                               [0.229, 0.224, 0.225])])
    test_valid_transforms = transforms.Compose([transforms.Resize(255),
                                                transforms.CenterCrop(224),
                                                transforms.ToTensor(),
                                                transforms.Normalize([0.485, 0.456, 0.406],
                                                                     [0.229, 0.224, 0.225])])
    trainset = FlowersDataset(train=True, transform=transform_train)
    validset = FlowersDataset(train=True, transform=test_valid_transforms)
    testset = FlowersDataset(train=False, transform=test_valid_transforms)
    num_train = len(trainset)
    indices = torch.randperm(num_train).tolist()
    split = int(np.floor(valid_size * num_train))
    train_idx, valid_idx = indices[split:], indices[:split]
    # print(len(valid_idx))
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                              num_workers=num_workers, pin_memory=True)
    valid_loader = DataLoader(validset, batch_size=batch_size, sampler=valid_sampler,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, valid_loader, test_loader
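And this is roughly how I call it (the batch size here is just an example; the failure shows up on the next(iter(...)) line):

# Illustrative call site: build the loaders and pull one batch.
train_loader, valid_loader, test_loader = read_data(batch_size=32)
images, labels = next(iter(train_loader))   # sometimes works, sometimes raises the error below
print(images.shape, labels.shape)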
AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/anaconda3/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/anaconda3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/anaconda3/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/abdolghaniebrahimi/Documents/neuron-compression/untitled10.py", line 102, in __getitem__
    image = Image.fromarray(self.hf_data.get('X' + str(idx)).value)
AttributeError: 'NoneType' object has no attribute 'value'