I’m using Ubuntu 16.04.
Here is the code of my DataLoader:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import h5py
import lmdb
import os
import numpy as np
import random
import torch
import torch.utils.data as data
import multiprocessing
import six
class HybridLoader:
"""
    If db_path is a directory, then use normal file loading
    If lmdb, then load from lmdb
    The loading method depends on the extension.
"""
def __init__(self, db_path, ext):
self.db_path = db_path
self.ext = ext
if self.ext == '.npy':
self.loader = lambda x: np.load(x)
else:
self.loader = lambda x: np.load(x)['feat']
if db_path.endswith('.lmdb'):
self.db_type = 'lmdb'
self.env = lmdb.open(db_path, subdir=os.path.isdir(db_path),
readonly=True, lock=False,
readahead=False, meminit=False)
elif db_path.endswith('.pth'): # Assume a key,value dictionary
self.db_type = 'pth'
self.feat_file = torch.load(db_path)
self.loader = lambda x: x
print('HybridLoader: ext is ignored')
else:
self.db_type = 'dir'
def get(self, key):
if self.db_type == 'lmdb':
env = self.env
with env.begin(write=False) as txn:
                byteflow = txn.get(key.encode())  # lmdb keys must be bytes (a plain str only works on py2)
f_input = six.BytesIO(byteflow)
elif self.db_type == 'pth':
f_input = self.feat_file[key]
else:
f_input = os.path.join(self.db_path, key + self.ext)
# load image
feat = self.loader(f_input)
return feat
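# A minimal sketch of how HybridLoader dispatches (path and key below are hypothetical):
#   att_loader = HybridLoader('data/cocotalk_att', '.npz')     # db_type == 'dir'
#   feat = att_loader.get('12345')    # loads data/cocotalk_att/12345.npz['feat']
#   fc_loader = HybridLoader('data/cocotalk_fc.lmdb', '.npy')  # db_type == 'lmdb'
#   feat = fc_loader.get('12345')     # np.load on the bytes fetched from lmdb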
class DataLoader(data.Dataset):
def reset_iterator(self, split):
del self._prefetch_process[split]
self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
self.iterators[split] = 0
def get_vocab_size(self):
return self.vocab_size
def get_vocab(self):
return self.ix_to_word
def get_seq_length(self):
return self.seq_length
def __init__(self, opt):
self.opt = opt
self.batch_size = self.opt.batch_size
self.seq_per_img = opt.seq_per_img
# feature related options
self.use_fc = getattr(opt, 'use_fc', True)
self.use_att = getattr(opt, 'use_att', True)
self.use_box = getattr(opt, 'use_box', 0)
self.norm_att_feat = getattr(opt, 'norm_att_feat', 0)
self.norm_box_feat = getattr(opt, 'norm_box_feat', 0)
# load the json file which contains additional information about the dataset
print('DataLoader loading json file: ', opt.input_json)
self.info = json.load(open(self.opt.input_json))
if 'ix_to_word' in self.info:
self.ix_to_word = self.info['ix_to_word']
self.vocab_size = len(self.ix_to_word)
print('vocab size is ', self.vocab_size)
# open the hdf5 file
print('DataLoader loading h5 file: ', opt.input_fc_dir, opt.input_att_dir, opt.input_box_dir, opt.input_label_h5)
if self.opt.input_label_h5 != 'none':
self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')
# load in the sequence data
seq_size = self.h5_label_file['labels'].shape
self.label = self.h5_label_file['labels'][:]
self.seq_length = seq_size[1]
print('max sequence length in data is', self.seq_length)
# load the pointers in full to RAM (should be small enough)
self.label_start_ix = self.h5_label_file['label_start_ix'][:]
self.label_end_ix = self.h5_label_file['label_end_ix'][:]
else:
self.seq_length = 1
self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy')
self.att_loader = HybridLoader(self.opt.input_att_dir, '.npz')
self.box_loader = HybridLoader(self.opt.input_box_dir, '.npy')
self.num_images = len(self.info['images']) # self.label_start_ix.shape[0]
print('read %d image features' %(self.num_images))
# separate out indexes for each of the provided splits
self.split_ix = {'train': [], 'val': [], 'test': []}
for ix in range(len(self.info['images'])):
img = self.info['images'][ix]
            if 'split' not in img:
self.split_ix['train'].append(ix)
self.split_ix['val'].append(ix)
self.split_ix['test'].append(ix)
elif img['split'] == 'train':
self.split_ix['train'].append(ix)
elif img['split'] == 'val':
self.split_ix['val'].append(ix)
elif img['split'] == 'test':
self.split_ix['test'].append(ix)
elif opt.train_only == 0: # restval
self.split_ix['train'].append(ix)
print('assigned %d images to split train' %len(self.split_ix['train']))
print('assigned %d images to split val' %len(self.split_ix['val']))
print('assigned %d images to split test' %len(self.split_ix['test']))
self.iterators = {'train': 0, 'val': 0, 'test': 0}
        self._prefetch_process = {} # The three prefetch processes
for split in self.iterators.keys():
self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
        # Terminate the child processes when the parent exits
def cleanup():
print('Terminating BlobFetcher')
for split in self.iterators.keys():
del self._prefetch_process[split]
import atexit
atexit.register(cleanup)
def get_captions(self, ix, seq_per_img):
# fetch the sequence labels
ix1 = self.label_start_ix[ix] - 1 #label_start_ix starts from 1
ix2 = self.label_end_ix[ix] - 1
ncap = ix2 - ix1 + 1 # number of captions available for this image
assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'
if ncap < seq_per_img:
# we need to subsample (with replacement)
seq = np.zeros([seq_per_img, self.seq_length], dtype = 'int')
for q in range(seq_per_img):
ixl = random.randint(ix1,ix2)
seq[q, :] = self.label[ixl, :self.seq_length]
else:
ixl = random.randint(ix1, ix2 - seq_per_img + 1)
seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
return seq
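    # Example of the 1-based pointer convention used above (hypothetical values):
    #   label_start_ix[ix] = 5, label_end_ix[ix] = 9  ->  ncap = 5 captions,
    #   stored in self.label rows 4..8 (0-based, inclusive)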
def get_batch(self, split, batch_size=None):
batch_size = batch_size or self.batch_size
seq_per_img = self.seq_per_img
fc_batch = [] # np.ndarray((batch_size * seq_per_img, self.opt.fc_feat_size), dtype = 'float32')
att_batch = [] # np.ndarray((batch_size * seq_per_img, 14, 14, self.opt.att_feat_size), dtype = 'float32')
label_batch = [] #np.zeros([batch_size * seq_per_img, self.seq_length + 2], dtype = 'int')
wrapped = False
infos = []
gts = []
for i in range(batch_size):
# fetch image
tmp_fc, tmp_att, tmp_seq, \
ix, tmp_wrapped = self._prefetch_process[split].get()
if tmp_wrapped:
wrapped = True
fc_batch.append(tmp_fc)
att_batch.append(tmp_att)
tmp_label = np.zeros([seq_per_img, self.seq_length + 2], dtype = 'int')
if hasattr(self, 'h5_label_file'):
tmp_label[:, 1 : self.seq_length + 1] = tmp_seq
label_batch.append(tmp_label)
# Used for reward evaluation
if hasattr(self, 'h5_label_file'):
gts.append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
else:
gts.append([])
# record associated info as well
info_dict = {}
info_dict['ix'] = ix
info_dict['id'] = self.info['images'][ix]['id']
info_dict['file_path'] = self.info['images'][ix].get('file_path', '')
infos.append(info_dict)
# #sort by att_feat length
# fc_batch, att_batch, label_batch, gts, infos = \
# zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True))
        # dummy sort: the constant key preserves the original order; the zip/unzip
        # plumbing is kept so the length-sort above can be re-enabled easily
        fc_batch, att_batch, label_batch, gts, infos = \
            zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True))
data = {}
data['fc_feats'] = np.stack(sum([[_]*seq_per_img for _ in fc_batch], []))
# merge att_feats
max_att_len = max([_.shape[0] for _ in att_batch])
data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, max_att_len, att_batch[0].shape[1]], dtype = 'float32')
for i in range(len(att_batch)):
data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = att_batch[i]
data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32')
for i in range(len(att_batch)):
data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = 1
# set att_masks to None if attention features have same length
if data['att_masks'].sum() == data['att_masks'].size:
data['att_masks'] = None
data['labels'] = np.vstack(label_batch)
# generate mask
nonzeros = np.array(list(map(lambda x: (x != 0).sum()+2, data['labels'])))
mask_batch = np.zeros([data['labels'].shape[0], self.seq_length + 2], dtype = 'float32')
for ix, row in enumerate(mask_batch):
row[:nonzeros[ix]] = 1
data['masks'] = mask_batch
        data['gts'] = gts # all ground truth captions for each image
data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}
data['infos'] = infos
        data = {k:torch.from_numpy(v) if type(v) is np.ndarray else v for k,v in data.items()} # turn all ndarrays into torch tensors
return data
    # It's not coherent to make DataLoader a subclass of Dataset, but essentially we only
    # need to implement the following two functions, so that torch.utils.data.DataLoader
    # can load the data according to the index.
    # However, it's the minimal change needed to switch to pytorch data loading.
def __getitem__(self, index):
"""This function returns a tuple that is further passed to collate_fn
"""
ix = index #self.split_ix[index]
if self.use_att:
att_feat = self.att_loader.get(str(self.info['images'][ix]['id']))
# Reshape to K x C
att_feat = att_feat.reshape(-1, att_feat.shape[-1])
if self.norm_att_feat:
att_feat = att_feat / np.linalg.norm(att_feat, 2, 1, keepdims=True)
if self.use_box:
box_feat = self.box_loader.get(str(self.info['images'][ix]['id']))
                # divided by image width and height
x1,y1,x2,y2 = np.hsplit(box_feat, 4)
h,w = self.info['images'][ix]['height'], self.info['images'][ix]['width']
box_feat = np.hstack((x1/w, y1/h, x2/w, y2/h, (x2-x1)*(y2-y1)/(w*h))) # question? x2-x1+1??
if self.norm_box_feat:
box_feat = box_feat / np.linalg.norm(box_feat, 2, 1, keepdims=True)
att_feat = np.hstack([att_feat, box_feat])
# sort the features by the size of boxes
att_feat = np.stack(sorted(att_feat, key=lambda x:x[-1], reverse=True))
else:
att_feat = np.zeros((1,1,1), dtype='float32')
if self.use_fc:
fc_feat = self.fc_loader.get(str(self.info['images'][ix]['id']))
else:
fc_feat = np.zeros((1), dtype='float32')
if hasattr(self, 'h5_label_file'):
seq = self.get_captions(ix, self.seq_per_img)
else:
seq = None
return (fc_feat,
att_feat, seq,
ix)
def __len__(self):
return len(self.info['images'])
class SubsetSampler(torch.utils.data.sampler.Sampler):
r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (list): a list of indices
"""
def __init__(self, indices):
self.indices = indices
def __iter__(self):
return (self.indices[i] for i in range(len(self.indices)))
def __len__(self):
return len(self.indices)
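# Unlike torch's SubsetRandomSampler, this sampler yields the given indices
# in order, e.g. SubsetSampler([7, 3, 5]) iterates 7, then 3, then 5.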
class BlobFetcher():
"""Experimental class for prefetching blobs in a separate process."""
def __init__(self, split, dataloader, if_shuffle=False):
"""
db is a list of tuples containing: imcrop_name, caption, bbox_feat of gt box, imname
"""
self.split = split
self.dataloader = dataloader
self.if_shuffle = if_shuffle
# Add more in the queue
def reset(self):
"""
Two cases for this function to be triggered:
1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator
        2. wrapped: a new epoch; the split_ix and iterator have already been updated in _get_next_minibatch_inds.
"""
# batch_size is 1, the merge is done in DataLoader class
self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
batch_size=1,
sampler=SubsetSampler(self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:]),
shuffle=False,
pin_memory=True,
num_workers=4, # 4 is usually enough
collate_fn=lambda x: x[0]))
def _get_next_minibatch_inds(self):
max_index = len(self.dataloader.split_ix[self.split])
wrapped = False
ri = self.dataloader.iterators[self.split]
ix = self.dataloader.split_ix[self.split][ri]
ri_next = ri + 1
if ri_next >= max_index:
ri_next = 0
if self.if_shuffle:
random.shuffle(self.dataloader.split_ix[self.split])
wrapped = True
self.dataloader.iterators[self.split] = ri_next
return ix, wrapped
def get(self):
if not hasattr(self, 'split_loader'):
self.reset()
ix, wrapped = self._get_next_minibatch_inds()
        tmp = next(self.split_loader)  # the built-in next() works on both py2 and py3; .next() is py2-only
if wrapped:
self.reset()
assert tmp[-1] == ix, "ix not equal"
return tmp + [wrapped]
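For reference, a minimal sketch of driving this loader (the paths and option values are hypothetical; opt just needs the fields that __init__ reads):

import argparse
opt = argparse.Namespace(
    batch_size=16, seq_per_img=5, train_only=0,  # hypothetical values for illustration
    input_json='data/cocotalk.json',
    input_label_h5='data/cocotalk_label.h5',
    input_fc_dir='data/cocotalk_fc',
    input_att_dir='data/cocotalk_att',
    input_box_dir='data/cocotalk_box')
loader = DataLoader(opt)
batch = loader.get_batch('train')
print(batch['fc_feats'].shape, batch['labels'].shape)  # torch tensors after the ndarray conversion above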
I’m not sure, but I guess the problem is in the BlobFetcher class above, most likely somewhere in its get()/reset() path.
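That path can be exercised on its own like this (a minimal sketch; assumes loader is the DataLoader instance from above with a non-empty 'train' split):

fetcher = BlobFetcher('train', loader, if_shuffle=True)
fc_feat, att_feat, seq, ix, wrapped = fetcher.get()  # the same 5 values get_batch unpacks
print(ix, wrapped)  # served index, and whether the split wrapped around this step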