Hi!
I’m using a DataLoader with a custom sampler, but the dimensions I get are wrong. Could someone please give me some guidance? Thanks so much!
What I’m trying to do:
for task_batch in task_dataloader:
    for task in task_batch:
        support_image, support_label, query_image, query_label = task
        # support_image has shape (shots * ways, channel, height, width)
        # support_label has shape (shots * ways,)
        # query_image has shape (query * ways, channel, height, width)
        # query_label has shape (query * ways,)
The problem is that my task_batch ends up being a list of 4 entries:
# batch of support_image, shape (batch_size, shots * ways, channel, height, width)
# batch of support_label, shape (batch_size, shots * ways)
# batch of query_image, shape (batch_size, query * ways, channel, height, width)
# batch of query_label, shape (batch_size, query * ways)
So when I try to unpack a task, I get the wrong dimensions.
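If my understanding is right, the DataLoader’s default collate_fn does something like this to each batch of task tuples (a minimal sketch with made-up shapes, not my real data):
import torch
from torch.utils.data import default_collate  # public in recent PyTorch versions, I believe

# two fake "tasks", shaped like what my __getitem__ returns:
# 3 ways * 2 shots = 6 support images, 4 query images, 3x32x32 each
task_a = (torch.zeros(6, 3, 32, 32), torch.zeros(6), torch.zeros(4, 3, 32, 32), torch.zeros(4))
task_b = (torch.ones(6, 3, 32, 32), torch.ones(6), torch.ones(4, 3, 32, 32), torch.ones(4))

batch = default_collate([task_a, task_b])
print(type(batch), len(batch))  # <class 'list'> 4 -- a list of 4 entries, like I see
print(batch[0].shape)           # torch.Size([2, 6, 3, 32, 32]) -- batch dim prepended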
My code is below. (I verified that the dataset class’s __getitem__() returns images and labels of the correct shapes; I think the problem is that the DataLoader stacks them up in the wrong way.)
import numpy as np
import os
import torch
import matplotlib.pyplot as plt
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUMBER_OF_SAMPLES_EA_CLASS = 20
LABEL_NAMES: list[str] = [...]  # the full list of class names (omitted here; it's quite long)
DATA_FOLDER = "../data/sorted_by_class"
class LearningToBalanceDataset(torch.utils.data.Dataset):
    """A meta-learning dataset; items are tasks."""

    def __init__(self, max_number_of_shot: int, number_of_query: int,
                 way: int, task_dictionary: dict):
        super().__init__()
        self.way = way
        self.max_number_of_shot = max_number_of_shot
        self.number_of_query = number_of_query
        self.dictionary = task_dictionary
    def __getitem__(self, class_indices):
        # coin flip: either an independent shot count per class, or one shared shot count
        coin = np.random.uniform(low=0, high=1, size=1)
        if coin > 0.5:
            shot = np.random.choice(range(1, self.max_number_of_shot), size=self.way, replace=True)
        else:
            shot = np.random.choice(range(1, self.max_number_of_shot), size=1)
            shot = shot.repeat(repeats=self.way)
        images_support = []
        labels_support = []
        images_query = []
        labels_query = []
        for label, class_index in enumerate(class_indices):
            total_number_of_examples_in_class = len(self.dictionary[class_index]['image'])
            if self.max_number_of_shot + self.number_of_query > total_number_of_examples_in_class:
                raise ValueError("LearningToBalanceDataset: shots + query > total sample count!")
            actual_shot = shot[label]
            total_samples = actual_shot + self.number_of_query
            sample_indices = list(np.random.choice(range(total_number_of_examples_in_class), size=total_samples, replace=False))
            to_be_padded = self.dictionary[class_index]['image'][sample_indices[:actual_shot]].reshape([-1, 3, 32, 32])
            if actual_shot < self.max_number_of_shot:
                # zero-pad the support images (and their labels) up to max_number_of_shot
                pad_amount = self.max_number_of_shot - actual_shot
                image_support_padded = np.concatenate([
                    to_be_padded,
                    np.zeros(shape=(pad_amount, to_be_padded.shape[1], to_be_padded.shape[2], to_be_padded.shape[3])),
                ])
                label_support_padded = [label] * actual_shot + [0] * pad_amount
            else:
                image_support_padded = to_be_padded
                label_support_padded = [label] * actual_shot
            images_support.extend(torch.tensor(image_support_padded, dtype=torch.float32))
            labels_support.extend(label_support_padded)
            # the number of query images is fixed, so no padding needed
            images_query.extend(torch.tensor(self.dictionary[class_index]['image'][sample_indices[actual_shot:]].reshape([-1, 3, 32, 32]), dtype=torch.float32))
            labels_query.extend([label] * self.number_of_query)
        images_support = torch.stack(images_support)
        labels_support = torch.tensor(labels_support)
        images_query = torch.stack(images_query)
        labels_query = torch.tensor(labels_query)
        return images_support.to(DEVICE), labels_support.to(DEVICE), images_query.to(DEVICE), labels_query.to(DEVICE)
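This is roughly how I verified the shapes from __getitem__ directly (made-up numbers; data_dictionary is built the same way as in get_dataloader below):
# sanity check: calling __getitem__ directly gives the shapes I expect
dataset = LearningToBalanceDataset(max_number_of_shot=5, number_of_query=5,
                                   way=3, task_dictionary=data_dictionary)
support_image, support_label, query_image, query_label = dataset[np.array([0, 1, 2])]
print(support_image.shape)  # torch.Size([15, 3, 32, 32]) -- (shots * ways, C, H, W), no batch dim
print(query_image.shape)    # torch.Size([15, 3, 32, 32]) -- (query * ways, C, H, W)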
class LearningToBalanceSampler(torch.utils.data.Sampler):
    def __init__(self, indices_to_sample_from, way, total_tasks):
        super().__init__(None)
        self.indices = indices_to_sample_from
        self.way = way
        self.total_tasks = total_tasks

    def __iter__(self):
        # each element yielded is one task: an array of `way` distinct class labels
        return (
            np.random.default_rng().choice(
                self.indices,
                size=self.way,
                replace=False
            ) for _ in range(self.total_tasks)
        )

    def __len__(self):
        return self.total_tasks
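The way I understand it, the sampler yields one array of class indices per task, and the DataLoader then groups batch_size of them before collating (hypothetical numbers, just to illustrate):
sampler = LearningToBalanceSampler(indices_to_sample_from=[0, 1, 2, 3, 4], way=3, total_tasks=4)
for class_indices in sampler:
    print(class_indices)  # e.g. [4 0 2] -- the classes for one task
# with batch_size=2, the DataLoader calls dataset[class_indices] twice and
# collates the two resulting task tuples, which is where the extra dim appears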
def get_dataloader(class_names: list[str], max_number_of_shot: int,
                   number_of_query: int, way: int, total_tasks: int,
                   batch_size: int):
    data_dictionary = {}
    class_labels = []
    for class_name in class_names:
        class_data = np.load(f"{DATA_FOLDER}/{class_name}.npy")
        class_label = LABEL_NAMES.index(class_name)
        class_labels.append(class_label)
        data_dictionary[class_label] = {}
        data_dictionary[class_label]['image'] = class_data
        # one label per example in this class
        data_dictionary[class_label]['label'] = np.repeat(class_label, repeats=len(class_data))
    return torch.utils.data.DataLoader(
        dataset=LearningToBalanceDataset(max_number_of_shot, number_of_query, way, data_dictionary),
        sampler=LearningToBalanceSampler(indices_to_sample_from=class_labels, way=way, total_tasks=total_tasks),
        drop_last=True,
        batch_size=batch_size,
    )
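And this is roughly how I call it and what I observe (the arguments here are placeholders, not my real settings):
task_dataloader = get_dataloader(class_names=LABEL_NAMES[:10], max_number_of_shot=5,
                                 number_of_query=5, way=3, total_tasks=100, batch_size=4)
task_batch = next(iter(task_dataloader))
print(len(task_batch))      # 4 entries: support_image, support_label, query_image, query_label
print(task_batch[0].shape)  # torch.Size([4, 15, 3, 32, 32]) -- (batch_size, shots * ways, C, H, W)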