Hi! I am training a classification model on the EMNIST dataset. I noticed big differences in how the loss converges when I use the EMNIST dataset from torchvision (torchvision.datasets.EMNIST) instead of loading the data from CSV files. I tested both with the same model parameters and the same random seed.
I constructed a custom dataset class to fetch data from csv files:
import pandas as pd

def load_csv_data(csv_path: str):
    # Load data from CSV; used to create img_data and label_data
    data_df = pd.read_csv(csv_path)
    # First column holds the labels; uppercase and lowercase share a label,
    # e.g. "a" or "A" = 1 (range 1-26, shifted to 0-25 later by target_transform)
    labels = data_df.iloc[:, 0].values
    # Remaining columns are pixel values; orientation is fixed later by the image transforms
    raw_pixels = data_df.iloc[:, 1:].values
    # EMNIST images are 28x28 pixels, so this gives a 3D ndarray
    imgs = raw_pixels.reshape(-1, 28, 28)
    return imgs, labels
from numpy import ndarray
from torch.utils.data import Dataset

class EMNISTDataset(Dataset):
    def __init__(self, img_data: ndarray, label_data: ndarray, transform, target_transform, subset_indices: slice):
        self.img_subset = img_data[subset_indices]
        self.label_subset = label_data[subset_indices]
        if self.img_subset.shape[0] != self.label_subset.shape[0]:
            raise RuntimeError(f"Img subset len does not match label subset len: \n Img subset len: {self.img_subset.shape[0]} \n Label subset len:{self.label_subset.shape[0]}")
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        data_len = self.label_subset.shape[0]
        return data_len

    def __getitem__(self, idx):
        X = self.img_subset[idx, :, :]
        y = int(self.label_subset[idx])
        if self.transform:
            X = self.transform(X)  # applies the torchvision image transform
        if self.target_transform:
            y = self.target_transform(y)  # applies the label transform
        return X, y
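One difference from torchvision.datasets.EMNIST that I have not ruled out is the pixel dtype. The arrays returned by load_csv_data come from pd.read_csv, which typically yields int64 for integer columns, while the torchvision dataset hands out 8-bit images. As far as I understand, v2.ToDtype(torch.float32, scale=True) scales integer input relative to the maximum value of its dtype, so int64 pixels could end up very close to zero instead of in [0, 1]. I have not verified that this is the cause. A minimal sketch of casting before building the dataset, where to_uint8_images is a hypothetical helper and the array is a stand-in for the real data:

```python
import numpy as np

def to_uint8_images(imgs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: cast pixel arrays to uint8 after checking the value range."""
    if imgs.min() < 0 or imgs.max() > 255:
        raise ValueError("pixel values fall outside the 0-255 range")
    return imgs.astype(np.uint8)

# Stand-in for the int64 array that pd.read_csv(...).values typically produces
fake_pixels = np.array([[0, 128, 255]] * 2, dtype=np.int64)
converted = to_uint8_images(fake_pixels)
print(fake_pixels.dtype, "->", converted.dtype)
```

Printing X.dtype, X.min() and X.max() for one sample from each dataset after the transforms would show whether the two pipelines really feed the model the same value range.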
I am using these transforms:
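import torch
from torchvision.transforms import v2

img_transform = v2.Compose([
    v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),  # to tensor
    v2.RandomHorizontalFlip(p=1),  # always flip (100% probability)
    v2.RandomRotation(degrees=(90, 90)),  # always rotate by exactly 90 degrees
])
target_transform = lambda y: y - 1  # shift labels from 1-26 to 0-25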
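On the orientation: assuming RandomRotation rotates counter-clockwise for positive angles (which I believe is the torchvision convention), the flip plus rotation should amount to a plain transpose of each image. A small NumPy check of that equivalence, with flip_then_rotate as a hypothetical helper name:

```python
import numpy as np

def flip_then_rotate(img: np.ndarray) -> np.ndarray:
    """Horizontal flip followed by a 90-degree counter-clockwise rotation."""
    return np.rot90(np.fliplr(img), k=1)

# Non-square test image so that a shape mismatch would be caught too
img = np.arange(12).reshape(3, 4)
same_as_transpose = np.array_equal(flip_then_rotate(img), img.T)
print(same_as_transpose)
```

If that assumption holds, the same correction could be applied once in load_csv_data with imgs.transpose(0, 2, 1) instead of per sample.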
This is the CrossEntropyLoss from multiple runs. The worse ones (around 3.2) are from my custom dataset, while the better ones are from the torchvision dataset.
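For reference, a loss of around 3.2 is close to what a model would score if it predicted all classes with equal probability, assuming 26 letter classes (labels 0-25). A quick sketch of that baseline, where uniform_cross_entropy is a hypothetical helper:

```python
import math

def uniform_cross_entropy(num_classes: int) -> float:
    """Cross-entropy (natural log) of a uniform prediction over num_classes."""
    return -math.log(1.0 / num_classes)

baseline = uniform_cross_entropy(26)
print(f"{baseline:.3f}")
```

So the runs with my custom dataset look like the model is barely learning anything from the inputs, rather than just converging more slowly.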