PyTorch: DataLoader creates a new dimension when creating batches

When I loop over my DataLoader object with enumerate(), I get a new dimension that is added in order to create the batches of my data.

I have 4 tensors that I slice at a macro level (I have panel data, so I slice the data in blocks of individuals instead of rows, i.e. observations):

  • X (3D)
  • Y (2D)
  • Z (2D)
  • id (2D).

The data contain 10 observations but only 5 individuals (hence, each individual has 2 observations). Thus, each batch of my data has a minimum of two observations.

Since I am setting the batch_size = 2, I am taking 4 observations for the first and second batch, and only 2 for the third.
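As a quick sanity check of this arithmetic, here is a minimal sketch (the variable names are illustrative only and are not part of my actual code):

```python
# Illustrative arithmetic only: 5 individuals, 2 observations each, batch_size = 2.
n_individuals = 5
obs_per_individual = 2
batch_size = 2  # counted in individuals, because __len__ returns the number of individuals

# Individuals per batch: full batches first, then the remainder (drop_last=False).
individuals_per_batch = [
    min(batch_size, n_individuals - start)
    for start in range(0, n_individuals, batch_size)
]
observations_per_batch = [n * obs_per_individual for n in individuals_per_batch]

print(individuals_per_batch)   # [2, 2, 1]
print(observations_per_batch)  # [4, 4, 2]
```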

This behavior is represented in the output below:

Selection of the data for by __getitem__ for individual 1
torch.Size([2, 3, 3]) X_batch when selecting for ind 1
torch.Size([2, 3]) Z_batch when selecting for ind 1
torch.Size([2, 1]) Y_batch when selecting for ind 1


Selection of the data for by __getitem__ for individual 2
torch.Size([2, 3, 3]) X_batch when selecting for ind 2
torch.Size([2, 3]) Z_batch when selecting for ind 2
torch.Size([2, 1]) Y_batch when selecting for ind 2


Data of the Batch #  1 inside the enumerate
shape X (outside foo) torch.Size([2, 2, 3, 3]) # <<-- here I have a new dimension
shape Z (outside foo) torch.Size([2, 2, 3])
shape Y (outside foo) torch.Size([2, 2, 1])


Selection of the data for by __getitem__ for individual 3
torch.Size([2, 3, 3]) X_batch when selecting for ind 3
torch.Size([2, 3]) Z_batch when selecting for ind 3
torch.Size([2, 1]) Y_batch when selecting for ind 3


Selection of the data for by __getitem__ for individual 4
torch.Size([2, 3, 3]) X_batch when selecting for ind 4
torch.Size([2, 3]) Z_batch when selecting for ind 4
torch.Size([2, 1]) Y_batch when selecting for ind 4


Data of the Batch #  2 inside the enumerate 
shape X (outside foo) torch.Size([2, 2, 3, 3])  # <<-- here I have a new dimension
shape Z (outside foo) torch.Size([2, 2, 3])
shape Y (outside foo) torch.Size([2, 2, 1])


Selection of the data for by __getitem__ for individual 5
torch.Size([2, 3, 3]) X_batch when selecting for ind 5
torch.Size([2, 3]) Z_batch when selecting for ind 5
torch.Size([2, 1]) Y_batch when selecting for ind 5


Data of the Batch #  3 inside the enumerate 
shape X (outside foo) torch.Size([1, 2, 3, 3]) # <<-- here I have a new dimension
shape Z (outside foo) torch.Size([1, 2, 3])
shape Y (outside foo) torch.Size([1, 2, 1])

First, I select the data that corresponds to the first and second individuals, but inside the enumerate() loop I get a new dimension (dim 0), which PyTorch uses to stack the blocks of individuals.


So here is my question:

Is there any way of concatenating the blocks of data with torch.cat(..., axis=0) instead of creating this new dimension to store the entire batch of data?

So, for instance, for the first batch I want the following:

Data of the Batch #  1 inside the enumerate
shape X (outside foo) torch.Size([4, 3, 3]) # <<-- here I want torch.cat(..., axis=0)
shape Z (outside foo) torch.Size([4, 3])
shape Y (outside foo) torch.Size([4, 1])

The code that produces the output above is listed below. Thank you.


Sample data

import torch
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import argparse

# args to be passed to the model
parser = argparse.ArgumentParser(description='Neural network for Flexible utility (VOT =f(z))')
args = parser.parse_args("") 
args.J = 3 # number of alternatives

# Sample data
X =  pd.DataFrame.from_dict({'x1_1': {0: -0.1766214634108258, 1: 1.645852185286492, 2: -0.13348860101031038, 3: 1.9681043689968933, 4: -1.7004428240831382, 5: 1.4580091413853749, 6: 0.06504113741068565, 7: -1.2168493676768384, 8: -0.3071304478616376, 9: 0.07121332925591593}, 'x1_2': {0: -2.4207773498298844, 1: -1.0828751040719462, 2: 2.73533787008624, 3: 1.5979611987152071, 4: 0.08835542172064115, 5: 1.2209786277076156, 6: -0.44205979195950784, 7: -0.692872860268244, 8: 0.0375521181289943, 9: 0.4656030062266639}, 'x1_3': {0: -1.548320898226322, 1: 0.8457342014424675, 2: -0.21250514722879738, 3: 0.5292389938329516, 4: -2.593946520223666, 5: -0.6188958526077123, 6: 1.6949245117526974, 7: -1.0271341091035742, 8: 0.637561891142571, 9: -0.7717170035055559}, 'x2_1': {0: 0.3797245517345564, 1: -2.2364391598508835, 2: 0.6205947900678905, 3: 0.6623865847688559, 4: 1.562036259999875, 5: -0.13081282910947759, 6: 0.03914373833251773, 7: -0.995761652421108, 8: 1.0649494418154162, 9: 1.3744782478849122}, 'x2_2': {0: -0.5052556836786106, 1: 1.1464291788297152, 2: -0.5662380273138174, 3: 0.6875729143723538, 4: 0.04653136473130827, 5: -0.012885303852347407, 6: 1.5893672346098884, 7: 0.5464286050059511, 8: -0.10430829457707284, 9: -0.5441755265313813}, 'x2_3': {0: -0.9762973303149007, 1: -0.983731467806563, 2: 1.465827578266328, 3: 0.5325950414202745, 4: -1.4452121324204903, 5: 0.8148816373643869, 6: 0.470791989780882, 7: -0.17951636294180473, 8: 0.7351814781280054, 9: -0.28776723200679066}, 'x3_1': {0: 0.12751822396637064, 1: -0.21926633684030983, 2: 0.15758799357206943, 3: 0.5885412224632464, 4: 0.11916562911189271, 5: -1.6436210334529249, 6: -0.12444368631987467, 7: 1.4618564171802453, 8: 0.6847234328916137, 9: -0.23177118858569187}, 'x3_2': {0: -0.6452955690715819, 1: 1.052094761527654, 2: 0.20190339195326157, 3: 0.6839430295237913, 4: -0.2607691613858866, 5: 0.3315513026670213, 6: 0.015901139336566113, 7: 0.15243420084881903, 8: -0.7604225072161022, 9: -0.4387652927008854}, 
'x3_3': {0: -1.067058994377549, 1: 0.8026914180717286, 2: -1.9868531745912268, 3: -0.5057770735303253, 4: -1.6589569342151713, 5: 0.358172252880764, 6: 1.9238983803281329, 7: 2.2518318810978246, 8: -1.2781475121874357, 9: -0.7103081175166167}})
Y = pd.DataFrame.from_dict({'CHOICE': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 2.0, 6: 1.0, 7: 1.0, 8: 2.0, 9: 2.0}})
Z = pd.DataFrame.from_dict({'z1': {0: 2.4196730570917233, 1: 2.4196730570917233, 2: 2.822802255159467, 3: 2.822802255159467, 4: 2.073171091633643, 5: 2.073171091633643, 6: 2.044165101485163, 7: 2.044165101485163, 8: 2.4001241292606275, 9: 2.4001241292606275}, 'z2': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 1.0, 6: 1.0, 7: 1.0, 8: 0.0, 9: 0.0}, 'z3': {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 2.0, 5: 2.0, 6: 2.0, 7: 2.0, 8: 3.0, 9: 3.0}})
id = pd.DataFrame.from_dict({'id_choice': {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0, 5: 6.0, 6: 7.0, 7: 8.0, 8: 9.0, 9: 10.0}, 'id_ind': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 3.0, 6: 4.0, 7: 4.0, 8: 5.0, 9: 5.0}} )
# Create a dataframe with all the data 
data = pd.concat([id,X, Z, Y], axis=1)

Defining the torch.utils.data.Dataset()

# class to create a dataset for choice data
class ChoiceDataset_all(Dataset):
    '''
    Dataset for choice data

    Args:
        data (pandas dataframe): dataframe with all the data

    Returns:
        dictionary with the data for each individual

    '''
    def __init__(self, data,  args , id_variable:str = "id_ind" ):

        if id_variable not in data.columns:
            raise ValueError(f"Variable {id_variable} not in dataframe")
        
        self.data = data
        # select cluster variable
        self.cluster_ids = self.data[id_variable].unique()
        self.Y = torch.LongTensor(self.data['CHOICE'].values -1).reshape(len(self.data['CHOICE'].index),1)
        self.id = torch.LongTensor(self.data[id_variable].values).reshape(len(self.data[id_variable].index),1)
        # number of individuals (N_n)
        self.N_n = torch.unique(self.id).shape[0]
        # number of choices made per individual (t_n)
        _ , self.t_n = self.id.unique(return_counts=True)
        #total number of observations (N_t = total number of choices)
        self.N_t = self.t_n.sum(axis=0).item() 
        # Select regressors: variables that start with "x"
        self.X_wide = data.filter(regex='^x') 
        # turn X_wide into a tensor
        self.X = torch.DoubleTensor(self.X_wide.values)
        # number of regressors (K)
        self.K = int(self.X_wide.shape[1] / args.J)
        # reshape X to have the right dimensions
        # Select variables that start with "z"
        self.Z = torch.DoubleTensor(self.data.filter(regex='^z').values)

    def __len__(self):
        return self.N_n # number of individuals

    def __getitem__(self, idx):
        # select the index of the individual
        self.index = torch.where(self.id == idx+1)[0]
        self.len_batch =  self.index.shape[0] 
        # Select observations for the individual
        Y_batch = self.Y[self.index]
        Z_batch = self.Z[self.index]
        id_batch = self.id[self.index]
        X_batch = self.X[self.index]
        # reshape X_batch to (t_n, K, J); -1 infers J, so the method does not rely on the global `args`
        X_batch = X_batch.reshape(self.len_batch, self.K, -1)
        print("\n")
        print("Selection of the data for by __getitem__ for individual", idx+1)
        print(X_batch.shape, "X_batch when selecting for ind", idx+1)
        print(Z_batch.shape, "Z_batch when selecting for ind", idx+1)
        print(Y_batch.shape, "Y_batch when selecting for ind", idx+1)
        #print(id_batch.shape, "id_batch when selecting for ind", idx+1)
        return {'X': X_batch, 'Z': Z_batch, 'Y': Y_batch, 'id': id_batch}        



Looping over torch.utils.data.DataLoader()

choice_data = ChoiceDataset_all(data, args, id_variable="id_ind")
data_loader = DataLoader(choice_data, batch_size=2, shuffle=False, num_workers=0, drop_last=False)

for idx, data_dict in enumerate(data_loader):
    print("\n")
    print("Data of the Batch # ", idx+1, "inside the enumerate")
    print("shape X (outside foo)", data_dict['X'].shape)
    print("shape Z (outside foo)", data_dict['Z'].shape)
    print("shape Y (outside foo)", data_dict['Y'].shape)
#    print("shape id (outside foo)", data_dict['id'])

The additional dimension created by the DataLoader is the batch dimension which contains the batch_size samples (if possible).
If you want to move dim1 into this dimension (dim0) you could apply a view operation inside the training loop as: x = x.view(-1, x.size(2), x.size(3)). Let me know if this would work for you.
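As a minimal sketch of what happens (toy tensors with the same shapes as in the question; the names `samples`, `stacked` and `flat` are made up for illustration): the default collation stacks the samples along a new dim0, and the view merges dim0 and dim1, so the result should match `torch.cat` along dim0.

```python
import torch
from torch.utils.data.dataloader import default_collate

# Two toy "individuals", each with 2 observations of shape (3, 3).
samples = [
    {"X": torch.randn(2, 3, 3), "Y": torch.zeros(2, 1, dtype=torch.long)},
    {"X": torch.randn(2, 3, 3), "Y": torch.ones(2, 1, dtype=torch.long)},
]

# Default collation stacks the samples along a new leading batch dimension.
stacked = default_collate(samples)
print(stacked["X"].shape)  # torch.Size([2, 2, 3, 3])

# Merge the batch dimension (dim0) with the per-individual dimension (dim1).
x = stacked["X"]
flat = x.view(-1, x.size(2), x.size(3))
print(flat.shape)  # torch.Size([4, 3, 3])

# The result has the same values as concatenating the samples along dim0.
expected = torch.cat([s["X"] for s in samples], dim=0)
print(torch.equal(flat, expected))  # True
```

This only works because every sample in the batch has the same shape; the stacking step would fail otherwise.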

Thank you for your answer @ptrblck. This is certainly an option to solve the problem. However, I just got the suggestion to use a customized collate_fn here. Would you rather use a customized collate_fn or your approach of transforming the tensors while looping? Also, can you foresee any speed differences between the two approaches? Thank you again.

Just to be concrete: I solved the problem by using the following custom function:


def cust_collate(batch_dict):
    '''
    Collate function that concatenates the data for each individual using axis 0
    Args:
        batch_dict (dict): dictionary with the data for each individual
    Returns:
        batch_dict (dict): dictionary with the concatenated (axis = 0 ) data for each individual
    '''
    # concatenate the data for each individual using axis 0
    for i in range(len(batch_dict)):
        if i == 0:
            X = batch_dict[i]['X']
            Y = batch_dict[i]['Y']
            Z = batch_dict[i]['Z']
            id = batch_dict[i]['id']
        else:
            X = torch.cat((X, batch_dict[i]['X']), axis=0)
            Y = torch.cat((Y, batch_dict[i]['Y']), axis=0)
            Z = torch.cat((Z, batch_dict[i]['Z']), axis=0)
            id = torch.cat((id, batch_dict[i]['id']), axis=0)
    return {'X': X, 'Z': Z, 'Y': Y, 'id': id}

choice_data = ChoiceDataset_all(data, args, id_variable="id_ind")
data_loader = DataLoader(choice_data, batch_size=2, 
                         shuffle=False, num_workers=0, 
                         drop_last=False, 
                         collate_fn=cust_collate)

for idx, data_dict in enumerate(data_loader):
    print("\n")
    print("Data of the Batch # ", idx+1, "inside the enumerate")
    print("shape X (outside foo)", data_dict['X'].shape)
    print("shape Z (outside foo)", data_dict['Z'].shape)
    print("shape Y (outside foo)", data_dict['Y'].shape)

which, as you can check below, produces the dimensions I was looking for:



Selection of the data for by __getitem__ for individual 1
torch.Size([2, 3, 3]) X_batch when selecting for ind 1
torch.Size([2, 3]) Z_batch when selecting for ind 1
torch.Size([2, 1]) Y_batch when selecting for ind 1


Selection of the data for by __getitem__ for individual 2
torch.Size([2, 3, 3]) X_batch when selecting for ind 2
torch.Size([2, 3]) Z_batch when selecting for ind 2
torch.Size([2, 1]) Y_batch when selecting for ind 2
shape X (outside foo) torch.Size([4, 3, 3])
shape Z (outside foo) torch.Size([4, 3])
shape Y (outside foo) torch.Size([4, 1])


Selection of the data for by __getitem__ for individual 3
torch.Size([2, 3, 3]) X_batch when selecting for ind 3
torch.Size([2, 3]) Z_batch when selecting for ind 3
torch.Size([2, 1]) Y_batch when selecting for ind 3


Selection of the data for by __getitem__ for individual 4
torch.Size([2, 3, 3]) X_batch when selecting for ind 4
torch.Size([2, 3]) Z_batch when selecting for ind 4
torch.Size([2, 1]) Y_batch when selecting for ind 4
shape X (outside foo) torch.Size([4, 3, 3])
shape Z (outside foo) torch.Size([4, 3])
shape Y (outside foo) torch.Size([4, 1])


Selection of the data for by __getitem__ for individual 5
torch.Size([2, 3, 3]) X_batch when selecting for ind 5
torch.Size([2, 3]) Z_batch when selecting for ind 5
torch.Size([2, 1]) Y_batch when selecting for ind 5
shape X (outside foo) torch.Size([2, 3, 3])
shape Z (outside foo) torch.Size([2, 3])
shape Y (outside foo) torch.Size([2, 1])

Would you have done this somewhat differently?

Best regards and thank you again!
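For reference, a shorter variant of the same collate function, as a sketch: it calls `torch.cat` once per key instead of once per sample, and because it never stacks the samples it should also work when individuals have different numbers of observations (an unbalanced panel). The name `cat_collate` and the toy batch are made up for illustration.

```python
import torch

def cat_collate(batch):
    """Concatenate a list of per-individual dicts along dim0, key by key."""
    # `batch` is a list with one dict per individual, as returned by __getitem__.
    return {key: torch.cat([sample[key] for sample in batch], dim=0)
            for key in batch[0]}

# Toy batch: individual 1 has 2 observations, individual 2 has 3.
batch = [
    {"X": torch.randn(2, 3, 3), "Y": torch.zeros(2, 1, dtype=torch.long)},
    {"X": torch.randn(3, 3, 3), "Y": torch.ones(3, 1, dtype=torch.long)},
]
out = cat_collate(batch)
print(out["X"].shape)  # torch.Size([5, 3, 3])
print(out["Y"].shape)  # torch.Size([5, 1])
```

It would be passed to the loader the same way, via `collate_fn=cat_collate`.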

I would assume iterating the batch and concatenating into temporary tensors would be slower than a simple view operation as the latter does not create a copy of the data.
However, depending on your use case and if the data processing is done fast enough in the background you might not see an actual speed difference (but could of course profile the code to double check).
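A small sketch of the copy-versus-view difference, using a toy tensor shaped like a default-collated `X` (the names `viewed` and `catted` are illustrative):

```python
import torch

# Toy batch shaped like the default-collated X: (individuals, obs, K, J).
x = torch.randn(2, 2, 3, 3)

# view() only changes the shape metadata; it reuses the same storage.
viewed = x.view(-1, x.size(2), x.size(3))
print(viewed.data_ptr() == x.data_ptr())  # True

# torch.cat() allocates a new tensor and copies the inputs into it.
catted = torch.cat([x[0], x[1]], dim=0)
print(catted.data_ptr() == x.data_ptr())  # False
print(torch.equal(viewed, catted))        # True
```

Note that the default collation still copies once when it stacks the samples; the view only avoids a second copy.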


Thank you for this insight. I profiled both solutions, and below you can see the results. As you mentioned, a simple view while looping was much faster than iterating over the batch using my custom collate_fn.

However, I suspect that another, more efficient collate_fn might outperform the view-while-looping approach; otherwise, what's the purpose of having collate_fn, right? Or am I missing something here? Thank you again (AS ALWAYS!)

#%%
import argparse
import cProfile

import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


# class to create a dataset for choice data
class ChoiceDataset_all(Dataset):
    '''
    Dataset for choice data

    Args:
        data (pandas dataframe): dataframe with all the data

    Returns:
        dictionary with the data for each individual

    '''
    def __init__(self, data,  args , id_variable:str = "id_ind" ):

        if id_variable not in data.columns:
            raise ValueError(f"Variable {id_variable} not in dataframe")
        
        self.data = data
        self.args = args
        # select cluster variable
        self.cluster_ids = self.data[id_variable].unique()
        self.Y = torch.LongTensor(self.data['CHOICE'].values -1).reshape(len(self.data['CHOICE'].index),1)
        self.id = torch.LongTensor(self.data[id_variable].values).reshape(len(self.data[id_variable].index),1)
        # number of individuals (N_n)
        self.N_n = torch.unique(self.id).shape[0]
        # number of choices made per individual (t_n)
        _ , self.t_n = self.id.unique(return_counts=True)
        #total number of observations (N_t = total number of choices)
        self.N_t = self.t_n.sum(axis=0).item() 
        # Select regressors: variables that start with "x"
        self.X_wide = data.filter(regex='^x') 
        # turn X_wide into a tensor
        self.X = torch.DoubleTensor(self.X_wide.values)
        # number of regressors (K)
        self.K = int(self.X_wide.shape[1] / self.args.J)
        # reshape X to have the right dimensions
        # Select variables that start with "z"
        self.Z = torch.DoubleTensor(self.data.filter(regex='^z').values)

    def __len__(self):
        return self.N_n # number of individuals

    def __getitem__(self, idx):
        # select the index of the individual
        self.index = torch.where(self.id == idx+1)[0]
        self.len_batch =  self.index.shape[0] 
        #print("len_batch", self.len_batch)
        # Select observations for the individual
        Y_batch = self.Y[self.index]
        Z_batch = self.Z[self.index]
        id_batch = self.id[self.index]
        X_batch = self.X[self.index]
        # reshape X_batch to have the right dimensions
        X_batch = X_batch.reshape(self.len_batch,self.K,self.args.J)
        return {'X': X_batch, 'Z': Z_batch, 'Y': Y_batch, 'id': id_batch}        


# Extend this function to handle batch of tensors
def cust_collate(batch_dict):
    '''
    Collate function that concatenates the data for each individual using axis 0
    Args:
        batch_dict (dict): dictionary with the data for each individual
    Returns:
        batch_dict (dict): dictionary with the concatenated (axis = 0 ) data for each individual
    '''
    # concatenate the data for each individual using axis 0
    for i in range(len(batch_dict)):
        if i == 0:
            X = batch_dict[i]['X']
            Y = batch_dict[i]['Y']
            Z = batch_dict[i]['Z']
            id = batch_dict[i]['id']
        else:
            X = torch.cat((X, batch_dict[i]['X']), axis=0)
            Y = torch.cat((Y, batch_dict[i]['Y']), axis=0)
            Z = torch.cat((Z, batch_dict[i]['Z']), axis=0)
            id = torch.cat((id, batch_dict[i]['id']), axis=0)
    return {'X': X, 'Z': Z, 'Y': Y, 'id': id}


def resize_batch_3D_tensor(x:torch.Tensor):
    return  x.view(-1, x.size(2), x.size(3))

def resize_batch_2D_tensor(y:torch.Tensor):
    return  y.view(-1,  y.size(2))

#%% profiling the code



# Create a clustered dataset 

# args to be passed to the model
parser = argparse.ArgumentParser(description='Neural network for Flexible utility (VOT =f(z))')
args = parser.parse_args("") 
args.J = 3 # number of alternatives

# Data parameters
N = 10000
T = 500
N_t = N * T
K = 50
J = 3
# Create data
x = torch.randn((N_t, K * J ))
y = torch.randn((N_t,)).reshape(N_t,1)
# create correlative sequence of numbers using torch.arange with T observations per individual N
id_ind = torch.arange(1, N+1, 1).repeat(T).sort()[0].reshape(N_t,1)
data_sim = pd.DataFrame(torch.cat((x, y, id_ind), axis=1).numpy())
data_sim.columns = [f'x{i}' for i in range(K *J)] + ['CHOICE', 'id_ind']

# Defining the dataset
choice_data = ChoiceDataset_all(data_sim, args, id_variable="id_ind")
batch_size = 100
# Defining the dataloaders
DL_default_collate = DataLoader(choice_data, batch_size=batch_size, shuffle=False, num_workers=0, drop_last=False)
DL_custom_collate = DataLoader(choice_data, batch_size=batch_size, collate_fn=cust_collate, shuffle=False, num_workers=0, drop_last=False)

# Functions to profile
def foo_custom(DL):
    for idx, data_dict in enumerate(DL):
        X = data_dict['X']
        Y = data_dict['Y']
    return True


def foo_default(DL):
    for idx, data_dict in enumerate(DL):
        X = resize_batch_3D_tensor(data_dict['X'])
        Y = resize_batch_2D_tensor(data_dict['Y'])

    return True

Profiling results

# cProfile.run() prints the report and returns None, so there is nothing to assign
cProfile.run('foo_custom(DL_custom_collate)', sort='cumtime')
# >> 71344 function calls in 151.814 seconds
cProfile.run('foo_default(DL_default_collate)', sort='cumtime')
# >> 35144 function calls (34744 primitive calls) in 67.471 seconds
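In case it is useful, here is a sketch of how the report could be captured as text instead of only being printed, using `cProfile.Profile` together with `pstats`. The helper `profile_call` and the two toy functions are hypothetical and much smaller than the DataLoader comparison above, so the timings are not comparable to my results.

```python
import cProfile
import io
import pstats

import torch

def flatten_with_view(x):
    # Merge dim0 and dim1 without copying.
    return x.view(-1, x.size(2), x.size(3))

def flatten_with_cat(x):
    # Same result, but torch.cat copies every block into a new tensor.
    return torch.cat(list(x), dim=0)

def profile_call(fn, *args, repeats=100):
    """Run fn(*args) `repeats` times under cProfile and return the report text."""
    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(repeats):
        fn(*args)
    profiler.disable()
    stream = io.StringIO()
    pstats.Stats(profiler, stream=stream).sort_stats("cumtime").print_stats(5)
    return stream.getvalue()

x = torch.randn(100, 50, 50, 3)
report_view = profile_call(flatten_with_view, x)
report_cat = profile_call(flatten_with_cat, x)
print(report_view)
print(report_cat)
```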

I don’t fully understand this statement. The collate_fn is used to create a single batched tensor from a list of samples. The view operation could be applied in a custom collate_fn or inside the DataLoader loop, and I would not expect to see any difference since it’s a simple manipulation of the tensor’s metadata.

You are right, @ptrblck. I just noticed that default_collate() is a standalone function, so you can basically integrate your approach inside a custom collate_fn, like the code below. Thank you again for your help!!


from torch.utils.data.dataloader import default_collate


def cust_collate(batch_dict):
    '''
    This function is used to concatenate the data for each individual.
    It relies on the function `default_collate()` to stack the
    blocks of individuals along a new axis 0. It then reshapes the tensors
    so that the data for all individuals are concatenated along axis 0.
    
    Parameters
    ----------
    batch_dict : list of dict
        List with one dictionary per individual in the batch.
        Keys are 'X', 'Y', 'Z' and 'id'
    Returns
    -------
    collate_batche : dict
        Dictionary containing the concatenated data for the batch.
        Keys are 'X', 'Y', 'Z' and 'id' (again), plus 'N_n_batch' and 'N_t_batch'
    '''

    def resize_batch_from_4D_to_3D_tensor(x:torch.Tensor):
        """
            This function removes the extra dimension created by 
            `default_collate()` by merging it with dimension 1, letting 
            `view` infer the size of the new dimension zero (`-1`).
        Parameters
        ----------
        x : torch.Tensor (4D)
        Returns
        -------
        torch.Tensor (3D)                
        """
        return  x.view(-1, x.size(2), x.size(3))

    def resize_batch_from_3D_to_2D_tensor(y:torch.Tensor):
        """
            This function removes the extra dimension created by 
            `default_collate()` by merging it with dimension 1, letting 
            `view` infer the size of the new dimension zero (`-1`).
        Parameters
        ----------
        y : torch.Tensor (3D)
        Returns
        -------
        torch.Tensor (2D)                
        """
        return  y.view(-1,  y.size(2))    


    # stack the data for each individual along a new axis 0.
    collate_batche = default_collate(batch_dict)


    # Resize the tensors to concatenate the data for each individual using axis 0.
    collate_batche['X'] = resize_batch_from_4D_to_3D_tensor(collate_batche['X'])
    collate_batche['Y'] = resize_batch_from_3D_to_2D_tensor(collate_batche['Y'])
    collate_batche['Z'] = resize_batch_from_3D_to_2D_tensor(collate_batche['Z'])
    collate_batche['id'] = resize_batch_from_3D_to_2D_tensor(collate_batche['id'])
    # Number of individuals in the batch
    collate_batche['N_n_batch'] = torch.unique(collate_batche['id']).shape[0]
    # Total Number of choices sets in the batch     
    collate_batche['N_t_batch'] = collate_batche['Y'].shape[0]
    
    return collate_batche
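A compact, self-contained variant of the same idea, as a sketch only: the `ToyPanel` dataset and the name `flatten_collate` are hypothetical stand-ins for the real dataset and collate function, and `flatten(0, 1)` is used in place of the two `view` helpers so one line handles tensors of any rank.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

class ToyPanel(Dataset):
    """Hypothetical balanced panel: 5 individuals, 2 observations each."""
    def __init__(self, n_ind=5, t=2, k=3, j=3):
        self.X = torch.randn(n_ind, t, k, j)
        self.Y = torch.zeros(n_ind, t, 1, dtype=torch.long)

    def __len__(self):
        return self.X.size(0)  # number of individuals

    def __getitem__(self, idx):
        return {"X": self.X[idx], "Y": self.Y[idx]}

def flatten_collate(batch):
    # Stack with the default collation, then merge dim0 and dim1 of every tensor.
    collated = default_collate(batch)
    return {key: value.flatten(0, 1) for key, value in collated.items()}

loader = DataLoader(ToyPanel(), batch_size=2, collate_fn=flatten_collate)
shapes = [tuple(b["X"].shape) for b in loader]
print(shapes)  # [(4, 3, 3), (4, 3, 3), (2, 3, 3)]
```

Because it goes through `default_collate()`, this assumes a balanced panel: every individual in a batch must have the same number of observations.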