Here is the collate callable. It receives ALL the per-class data sets, builds a data loader for each sampled class, and returns a meta-batch of N-way, K-shot tasks:
class GetMetaBatch_NK_WayClassTask:
    """Collate callable that builds a meta-batch of M N-way, K-shot tasks.

    Each call draws ``meta_batch_size`` tasks. For every task it picks
    ``n_classes`` distinct entries of ``all_datasets`` at random, loads one
    batch of ``k_shot + k_eval`` examples from each picked entry, and splits
    that batch into a support part (the first ``k_shot`` examples) and a
    query part (the remaining ``k_eval`` examples).

    Args:
        meta_batch_size: number of tasks (M) per meta-batch.
        n_classes: number of classes (N) sampled for each task.
        k_shot: number of support examples taken per class.
        k_eval: number of query examples taken per class.
        shuffle: forwarded to the per-class ``DataLoader``.
        pin_memory: forwarded to the per-class ``DataLoader``.
        original: if True, keep the labels yielded by the data sets;
            otherwise relabel each class with its position ``0..N-1``
            inside the task.
        flatten: if True, merge the class and example axes of every task:
            inputs are reshaped to ``(-1, *last_three_dims)`` and labels to
            ``(-1,)``.
    """

    def __init__(self, meta_batch_size, n_classes, k_shot, k_eval, shuffle=True, pin_memory=True, original=False, flatten=True):
        self.meta_batch_size = meta_batch_size
        self.n_classes = n_classes
        self.k_shot = k_shot
        self.k_eval = k_eval
        self.shuffle = shuffle
        self.pin_memory = pin_memory
        self.original = original
        self.flatten = flatten

    def __call__(self, all_datasets, verbose=False):
        """Sample a meta-batch of tasks from ``all_datasets``.

        Args:
            all_datasets: indexable collection of data sets, one per class;
                each must yield ``(x, y)`` pairs.
            verbose: if True, print the tensor sizes of every sampled task
                (before flattening).

        Returns:
            Tuple ``(batch_spt_x, batch_spt_y, batch_qry_x, batch_qry_y)``;
            each element stacks the corresponding per-task tensors along a
            new leading axis of length ``meta_batch_size``.

        Raises:
            ValueError: if ``n_classes`` exceeds ``len(all_datasets)`` (raised
                by ``random.sample``), or if a sampled data set yields fewer
                than ``k_shot + k_eval`` examples.
        """
        batch_spt_x, batch_spt_y, batch_qry_x, batch_qry_y = [], [], [], []
        for _ in range(self.meta_batch_size):
            spt_x, spt_y, qry_x, qry_y = self._sample_task(all_datasets, verbose)
            # append the N-way, K-shot task to the meta-batch of tasks
            batch_spt_x.append(spt_x)
            batch_spt_y.append(spt_y)
            batch_qry_x.append(qry_x)
            batch_qry_y.append(qry_y)
        # meta-batch of M N-way, K-shot tasks, e.g. [M, N*K, C, H, W] when flattened
        return (
            torch.stack(batch_spt_x),
            torch.stack(batch_spt_y),
            torch.stack(batch_qry_x),
            torch.stack(batch_qry_y),
        )

    def _load_class_batch(self, data_set):
        """Load one batch of ``k_shot + k_eval`` examples from a single data set."""
        batch_size = self.k_shot + self.k_eval
        loader = data.DataLoader(
            data_set,
            batch_size=batch_size,
            shuffle=self.shuffle,
            num_workers=0,  # the loader is used for a single batch only
            pin_memory=self.pin_memory,
        )
        data_x, data_y = next(iter(loader))
        # A short batch would leave the query inputs out of step with the
        # k_eval labels generated for them, so fail loudly instead.
        if len(data_x) != batch_size:
            raise ValueError(
                f'Data set yielded {len(data_x)} examples but k_shot + k_eval = {batch_size} are required.'
            )
        return data_x, data_y

    def _sample_task(self, all_datasets, verbose=False):
        """Draw one N-way, K-shot task and return ``(spt_x, spt_y, qry_x, qry_y)``."""
        n_indices = random.sample(range(len(all_datasets)), self.n_classes)
        spt_x, spt_y, qry_x, qry_y = [], [], [], []
        for i, n in enumerate(n_indices):
            data_x_n, data_y_n = self._load_class_batch(all_datasets[n])
            spt_x_n, qry_x_n = data_x_n[:self.k_shot], data_x_n[self.k_shot:]
            if self.original:
                # keep whatever labels the data set provides
                spt_y_n, qry_y_n = data_y_n[:self.k_shot], data_y_n[self.k_shot:]
            else:
                # relabel with the class position inside this task
                spt_y_n = torch.full((self.k_shot,), i, dtype=torch.long)
                qry_y_n = torch.full((self.k_eval,), i, dtype=torch.long)
            spt_x.append(spt_x_n)
            spt_y.append(spt_y_n)
            qry_x.append(qry_x_n)
            qry_y.append(qry_y_n)
        # N-way, K-shot task with a leading class axis, e.g. [N, K, C, H, W]
        spt_x, spt_y = torch.stack(spt_x), torch.stack(spt_y)
        qry_x, qry_y = torch.stack(qry_x), torch.stack(qry_y)
        if verbose:
            print(f'spt_x.size() = {spt_x.size()}')
            print(f'spt_y.size() = {spt_y.size()}')
            print(f'qry_x.size() = {qry_x.size()}')
            print(f'qry_y.size() = {qry_y.size()}')
            print()
        if self.flatten:
            # merge class and example axes; keeps the last three dims of each input
            CHW = qry_x.shape[-3:]
            spt_x, qry_x = spt_x.reshape(-1, *CHW), qry_x.reshape(-1, *CHW)
            spt_y, qry_y = spt_y.reshape(-1), qry_y.reshape(-1)
        return spt_x, spt_y, qry_x, qry_y
This collate callable is then passed as the `collate_fn` of another data loader here:
def get_meta_set_loader(meta_set, meta_batch_size, n_episodes, n_classes, k_shot, k_eval, pin_mem=True, n_workers=4):
    """Build the episodic data loader that yields meta-batches of N-way, K-shot tasks.

    Args:
        meta_set: the meta-set; ``len(meta_set)`` is used as the number of
            available classes.
        meta_batch_size: number of tasks per meta-batch (forwarded to the collator).
        n_episodes: number of episodes (forwarded to ``EpisodicSampler``).
        n_classes: number of classes per task (forwarded to the collator).
        k_shot: support examples per class (forwarded to the collator).
        k_eval: query examples per class (forwarded to the collator).
        pin_mem (bool, optional): forwarded as ``pin_memory`` to the returned
            loader. Returning CUDA tensors from data loaders is not recommended
            because of CUDA subtleties with multiprocessing, so pin the memory
            instead to make the later transfer to CUDA fast. Defaults to True.
        n_workers (int, optional): number of worker processes of the returned
            loader. Defaults to 4.

    Returns:
        A ``DataLoader`` over ``meta_set`` that uses ``EpisodicSampler`` as its
        batch sampler and ``GetMetaBatch_NK_WayClassTask`` as its collate function.

    Raises:
        ValueError: if ``n_classes`` is larger than ``len(meta_set)``.
    """
    if n_classes > len(meta_set):
        raise ValueError(
            f'You really want a N larger than the # classes in the meta-set? '
            f'n_classes={n_classes}, len(meta_set)={len(meta_set)}'
        )
    # Do not pin memory inside the collator: it runs in the worker processes
    # when n_workers > 0, and it stacks what it loads into new tensors, so any
    # pinning done there would be discarded. The outer loader below pins the
    # finished meta-batch when pin_mem is set.
    collator_nk_way = GetMetaBatch_NK_WayClassTask(meta_batch_size, n_classes, k_shot, k_eval, pin_memory=False)
    episodic_sampler = EpisodicSampler(total_classes=len(meta_set), n_episodes=n_episodes)
    episodic_metaloader = data.DataLoader(
        meta_set,
        num_workers=n_workers,
        pin_memory=pin_mem,  # to make moving to cuda more efficient
        collate_fn=collator_nk_way,  # does the collecting to return M N-way, K-shot tasks
        batch_sampler=episodic_sampler,  # for keeping track of the episode
    )
    return episodic_metaloader
(I will put together a smaller, self-contained example.)