Code extention from Cifar10 to MNIST not working

Hi everyone :),
Problem:
I have a code (that I provide in a short form below) which work perfectly when I run it with CIFAR10. I tried to adapt it to accept MNIST as well but, even after I took care of the differences in the data extention I’m still not able to run it on this second dataset

Brief explanation of the code
I want to merge classes of the dataset MNIST in two groups, let’s say even vs odd; I want the code to be flexible in the sebse that I want it to accept an arbitrary label mapping in input and also be able to set proportions between different classes.
Once I have set my class mapping I proceed selecting the data index associated to element of each group and create a dataloader of each group. I need this step because I have to be able to separate the gradient coming from each group; I know that there are alternative method to achieve this but so far this is the easiest implementation I could find. (params is a dict of parameters defined earlier in the code)
Code

class DatasetMeanStd:
    
    def __init__(self, params):
        """
        This class is a tool to compute mean and std to standardise (or check) your dataset
        the init function load the training and test dataset 
        Parameters
        ----------
        DatasetName : string
        this is a string that encode the dataset that will be used

        Returns
        -------
        None.

        """
        
        params = params.copy()
        self.DatasetName = self.params['Dataset']
        self.ClassesList = self.params['label_map'].keys() #get the list of images in the dataset in case of subset of classes
        self.transform = transforms.ToTensor()
        if(self.DatasetName=='CIFAR10'):
            self.train_data = datasets.CIFAR10(root = self.params['DataFolder'], train = True, download = True, transform = self.transform)
            self.test_data = datasets.CIFAR10(root = self.params['DataFolder'], train = False, download = True, transform = self.transform)
        elif(self.DatasetName=='MNIST'):
            self.train_data = datasets.MNIST(root = self.params['DataFolder'], train = True, download = True, transform = self.transform)
            self.test_data = datasets.MNIST(root = self.params['DataFolder'], train = False, download = True, transform = self.transform) 

                
    def Mean(self):
        """
        Compute the mean of the dataset for the standardization (image vectors normalization)
        Returns
        -------
        list
            mean value for each channel

        """
        
        
        
        if (self.DatasetName == 'CIFAR10'):
            imgs = [item[0] for item in self.train_data if item[1] in self.ClassesList]  # item[0] and item[1] are image and its label
            imgs = torch.stack(imgs, dim=0).numpy()
            
            # calculate mean over each channel (r,g,b)
            mean_r = imgs[:,0,:,:].mean()
            mean_g = imgs[:,1,:,:].mean()
            mean_b = imgs[:,2,:,:].mean()   

            return (mean_r, mean_g, mean_b)

        elif (self.DatasetName == 'MNIST'):
            imgs = [item[0] for item in self.train_data if item[1] in self.ClassesList]  # item[0] and item[1] are image and its label
            imgs = torch.stack(imgs, dim=0).numpy()
            # calculate mean over each channel (r,g,b)
            mean = imgs[:,0,:,:].mean()

            return mean        

    def Std(self):
        
        """
        Compute the std of the dataset for the standardization (image vectors normalization)

        Returns
        -------
        list
            std value for each channel

        """
        
        if (self.DatasetName == 'CIFAR10'):
            imgs = [item[0] for item in self.train_data] # item[0] and item[1] are image and its label
            imgs = torch.stack(imgs, dim=0).numpy()
            
            # calculate std over each channel (r,g,b)
            std_r = imgs[:,0,:,:].std()
            std_g = imgs[:,1,:,:].std()
            std_b = imgs[:,2,:,:].std()  
            
            return(std_r, std_g, std_b)

        elif (self.DatasetName == 'MNIST'):
            imgs = [item[0] for item in self.train_data] # item[0] and item[1] are image and its label
            imgs = torch.stack(imgs, dim=0).numpy()
            
            # calculate std over each channel (r,g,b)
            std = imgs[:,0,:,:].std()
 
            
            return std            
            

LM = {'even_&_odd':{0:0, 1: 1, 2:0, 3:1, 4:0, 5:1, 6:0, 7:1, 8:0, 9:1}} #label mapping 

param['label_map']= LM['even_&_odd']

#MEAN AND STD COMPUTATION FOR NORMALIZATION
DataMean = DatasetMeanStd(params).Mean()
DataStd = DatasetMeanStd(params).Std()
print("the Mean and Std used to standardize data are {} and {}".format(DataMean, DataStd))
  
if (params['Dataset']=='MNIST'):
    transform = transforms.Compose([
                transforms.ToTensor() 
                ,transforms.Normalize(mean=(DataMean,), std=(DataStd,))
                ])   
elif(params['Dataset']=='CIFAR10'):    
    transform = transforms.Compose([
            transforms.ToTensor(), #Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
           
            transforms.Normalize(DataMean, DataStd)]) #in this way we standardize only on the subset of dataset used
            #transforms.Normalize((0.49236655, 0.47394478, 0.41979155), (0.24703233, 0.24348505, 0.26158768))]) 
        
#DATASET DEFINITION

if (params['Dataset']=='MNIST'):
    print('this run used MNIST dataset', file = params['info_file_object'])
    train_data = datasets.MNIST(root = params['DataFolder'], train = True, download = True, transform = transform)
    test_data = datasets.MNIST(root = params['DataFolder'], train = False, download = True, transform = transform)
    valid_data = datasets.MNIST(root = params['DataFolder'], train = False, download = True, transform = transform)
    num_data = len(train_data)
    i=0
    for item in train_data:
        if i==0:
            print(item)
            print('image', item[0].size())
            print('label', item[1])
            print('image', item[0])
            i+=1
    print("total number of samples ", num_data, train_data.data.size())
elif(params['Dataset']=='CIFAR10'):
    print('this run used CIFAR10 dataset', file = params['info_file_object'])
    train_data = datasets.CIFAR10(root = params['DataFolder'], train = True, download = True, transform = transform)
    test_data = datasets.CIFAR10(root = params['DataFolder'], train = False, download = True, transform = transform) 
    valid_data = datasets.CIFAR10(root = params['DataFolder'], train = False, download = True, transform = transform) 
    num_data = len(train_data)
    i=0
    for item in train_data:
        if i==0:
            print(item)
            print('image', item[0].size())
            print('label', item[1])
            print('image', item[0])
            i+=1


#DATALOADER CREATION


            TrainDL = {}#dict to store data loader (one for each mapped class) for train set
            TestDL = {}#dict to store data loader (one for each mapped class) for test set
            ValidDL = {}#dict to store data loader (one for each mapped class) for valid set
            #define the batch sizr for each class such that their proportion will be near to "params['ImabalnceProportions']"
            #the advantage of proceding like that is that we can easly get the exact same number of batches per each class
            if params['OversamplingMode'] == 'OFF':
                #the batch size for each input class
                TrainClassBS = np.rint((params['batch_size']/np.sum(params['ImabalnceProportions']))*np.divide(params['ImabalnceProportions'], params['MappedClassOcc'])).astype(int)
                #the batch size for the whole associated output class given simply by the above expession multiplyed for the occurrences of each output class in the mapping
                TrainTotalClassBS = (np.rint((params['batch_size']/np.sum(params['ImabalnceProportions']))*np.divide(params['ImabalnceProportions'], params['MappedClassOcc'])).astype(int)*(params['MappedClassOcc'])).astype(int)
            elif params['OversamplingMode'] == 'ON':
                TrainClassBS = np.rint((params['batch_size']/model.num_classes)*np.reciprocal(params['MappedClassOcc'])).astype(int)
                TrainTotalClassBS = ((np.rint((params['batch_size']/model.num_classes)*np.reciprocal(params['MappedClassOcc'])).astype(int))*(params['MappedClassOcc'])).astype(int)
            print("real size of the batch size of the training set (after the roundings): {}".format(np.sum(TrainTotalClassBS)),flush=True, file = params['info_file_object']) 
            print("the total sizes of mapped classes are {}".format(self.TrainTotalClassBS))
            
            MajorInputClassBS = np.amax(TrainClassBS) #we select here the class with greater element in the batch; that one will establish the bottle neck for the dataset, we assign to it the maximum possible number of element            

            traintargets = torch.tensor(train_data.targets) #convert label in tensors
            validtargets = torch.tensor(valid_data.targets) #convert label in tensors
            testtargets = torch.tensor(test_data.targets) #convert label in tensors
            #first we cast the target label (originary a list) into a torch tensor
            #we then define a copy of them to avoid issue during the class mapping 
                #can happen for example, using only the above one that I want to map {0:1, 1:0} 
                #we map the 0s in 1 and then map 1s to 0 the list of 1s will include also the 0s mapped in the precedent steps; to avoid so we follow the following rule:
            train_data.targets = torch.tensor(train_data.targets)
            valid_data.targets = torch.tensor(valid_data.targets)
            test_data.targets = torch.tensor(test_data.targets)
                    
            TrainIdx = {}
            ValidIdx = {}
            TestIdx = {}
            for key in params['label_map']:
                print("the batch size for the class {}, mapped in {} is {}".format(key, params['label_map'][key], TrainClassBS[params['label_map'][key]]),flush=True, file = params['info_file_object'])
                #we start collecting the index associated to the output classes togeheter
                #TRAIN
                trainTarget_idx = (traintargets==key).nonzero() 
                #l0=int(900/MajorInputClassBS)*TrainClassBS[params['label_map'][key]] #just for debug purpose
                l0 = int(len(trainTarget_idx)/MajorInputClassBS)*TrainClassBS[params['label_map'][key]] #we first compute the numbers of batches for the majority class and then reproduce for all the others in such a way they will have same number of batches but with a proportion set by TrainClassBS[classcounter-1]            
                #Trainl0 = l0
                #WARNING: LINE ABOVE CHANGED WITH THE ONE BELOW ONLY FOR DEBUG PURPOSE (MNIST ADAPTATION) RESUBSTITUTE WITH THE ONE ABOVE ONCE SOLVED THE ISSUE
                Trainl0 = 5400
                print("the number of elements selected by the class {} loaded on the trainset is {}".format(key, Trainl0),flush=True, file = params['info_file_object'])
                #print(trainTarget_idx)
                ClassTempVar = '%s'%params['label_map'][key]
                
                #VALID
                validTarget_idx = (validtargets==key).nonzero()
                Validl0= 150 #should be less than 500 (since the total test set has 1000 images per class)                
                #TEST
                if (params['ValidMode']=='Test'): #if we are in testing mode we have to repeat it for a third dataset
                    testTarget_idx = (testtargets==key).nonzero()
                    Testl0= 150 #should be less than 500 (since the total test set has 1000 images per class)
                
                if ClassTempVar in TrainIdx: #if the mapped class has already appeared, we concatenate the new indeces to the existing ones
                    TrainIdx['%s'%params['label_map'][key]] = torch.cat((TrainIdx['%s'%params['label_map'][key]], trainTarget_idx[:][0:Trainl0]),0)
                    ValidIdx['%s'%params['label_map'][key]] = torch.cat((ValidIdx['%s'%params['label_map'][key]], validTarget_idx[:][0:Validl0]),0)
                    if (params['ValidMode']=='Test'): #if we are in testing mode we have to repeat it for a third dataset
                        TestIdx['%s'%params['label_map'][key]] = torch.cat((TestIdx['%s'%params['label_map'][key]], testTarget_idx[:][-Testl0:]),0)                   
                else: #if, instead the class is selected for the first time, we simply charge it on the indeces dict
                    TrainIdx['%s'%params['label_map'][key]] = trainTarget_idx[:][0:Trainl0]
                    ValidIdx['%s'%params['label_map'][key]] = validTarget_idx[:][0:Validl0] #select the last indeces for the validation so we don't have overlap increasing the size
                    if (params['ValidMode']=='Test'): #if we are in testing mode we have to repeat it for a third dataset
                        TestIdx['%s'%params['label_map'][key]] = testTarget_idx[:][-Testl0:] #select the last indeces for the validation so we don't have overlap increasing the size
            #REMAP THE LABELS: now that the indexes are fixed we map the dataset to the new labels
            for key in params['label_map']:               
                train_data.targets[traintargets==key]= params['label_map'][key] 
                valid_data.targets[validtargets==key]=params['label_map'][key]
                if (params['ValidMode']=='Test'):
                    test_data.targets[testtargets==key]=params['label_map'][key]
                    
            
            
            #DATALOADER CREATION    

#now we iterate over the mapped classes avoiding repetition
for MC in set(list(params['label_map'].values())): #with this syntax we avoid repetition of same dict items
    #TRAIN
    train_sampler = SubsetRandomSampler(TrainIdx['%s'%MC])  
    #if we are studing the class imbalance case we use the sampler option to select data
    #we load the dataloader corresponding to the mapped "params['label_map'][key]" class as a dict element
    TrainDL['Class%s'%MC] = torch.utils.data.DataLoader(train_data, batch_size = TrainTotalClassBS[MC].item(), 
                                           sampler = train_sampler, num_workers = params['num_workers'])     

    #VALID
    valid_sampler = SubsetRandomSampler(ValidIdx['%s'%MC])
    ValidDL['Class%s'%MC] = torch.utils.data.DataLoader(valid_data, batch_size = params['batch_size'], #note that for test and valid the choice of the batch size is not relevant (we use these dataset only in eval mode)
                                           sampler = valid_sampler, num_workers = params['num_workers']) 
    
    if (params['ValidMode']=='Test'):                    
        test_sampler = SubsetRandomSampler(TestIdx['%s'%MC])                    
        TestDL['Class%s'%MC] = torch.utils.data.DataLoader(test_data, batch_size = params['batch_size'], 
                                               sampler = test_sampler, num_workers = params['num_workers'])     
        
next(iter(self.TrainDL['Class0']))        

The last line, where we try to iterate over one of the defined dataloaders, produce an error, namely

ValueError: too many dimensions: 3>2

What am I doing wrong? What is difference between the two dataset I’m neglecting?