`forward()` works but `backward()` fails on the same batch

Although I provide a concrete (if rather convoluted) example below, I have a conceptual question that I haven't been able to answer by reading the documentation for torch.Tensor.backward():

My question is: how is it possible that, for a given batch, the loss is computed correctly, but computing the gradients with backward() then fails?

In particular, I get a shape-mismatch error of the form: RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0, and only when calling loss.backward(); the loss itself is computed without errors. The structure of the code is sketched below.

# Defining my model 
model = LL_MODEL(args)
# Training loop
for idx, data_batch  in enumerate(DataLoader_obj):
    # Compute the loss
    loss = model(data_batch)
    # Compute the gradient
    loss.backward() #<- this fails even when the loss is computed fine.

What I did not expect is this: if the loss can be computed correctly, I would assume that backward() should not have problems either. Some conceptual reasons why this behavior can happen would help my debugging. Thank you.
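For context on why this can happen at all: forward() and backward() run different code for each operation, so an op may accept input shapes in the forward pass that its gradient formula later cannot combine. One way to narrow down where backward() breaks is to attach gradient hooks to intermediate tensors and see which ones are reached. The sketch below uses a toy graph, not the model from this post; the helper trace_grad and the tensor names are made up for illustration.

```python
import torch


def trace_grad(tensor, name, log):
    """Record the shape of the gradient that reaches `tensor` during backward."""
    tensor.register_hook(lambda g: log.append((name, tuple(g.shape))))
    return tensor


log = []
w = torch.ones(3, dtype=torch.double, requires_grad=True)
x = torch.arange(12, dtype=torch.double).reshape(4, 3)

# Toy forward pass with two traced intermediate tensors
v = trace_grad(x * w, "v", log)                     # shape (4, 3)
p = trace_grad(torch.softmax(v, dim=1), "p", log)   # shape (4, 3)
loss = torch.log(p[:, 0]).sum()

loss.backward()
print(log)  # gradient shapes, in the order backward reached them
```

Hooks fire in the reverse order of the forward pass. If backward() raises, the failing op sits between the last tensor logged and the next traced tensor upstream of it.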


Additionally, before presenting the code, here are a few things I've checked and noticed while trying to debug it:

  1. All the parameters have requires_grad = True.

  2. Dimensions appear to match (the loss is computed without errors).

  3. With batch_size = 1 (a single individual, i.e. two records), the code runs without problems. That makes me think the problem is related to the panel structure of the data. Each batch is sampled at the level of groups (individuals, the id_ind variable) and not at the level of observations (rows); see the ChoiceDataset class below. The data contain 5 individuals with 2 records (rows) each, so 10 data points in total.

  4. With batch_size = 5 (the total number of individuals, so the batch is the entire dataset), the program also fails, with RuntimeError: The size of tensor a (10) must match the size of tensor b (5) at non-singleton dimension 0. Here 10 is the total number of rows in the example data and 5 is the number of individuals, which again points to the panel structure of the data as the source of the bug. However, the loss is computed without errors in all cases, regardless of batch_size.
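The pattern in items 3 and 4 (4 vs 2, 10 vs 5, and no error with a single individual, where a size-1 dimension can broadcast) is consistent with an index tensor that has one row per individual being combined with a source tensor that has one row per observation. The sketch below is a guess at the mechanism, not a confirmed diagnosis: it assumes the mismatch sits in a scatter_reduce call, uses made-up probabilities, and the helper simulated_prob is hypothetical. In the forward pass, scatter-type ops accept an index with fewer rows than src and only read the rows the index covers, so a result is produced even though part of the data is ignored. Whether backward() then raises may depend on the PyTorch version, so the sketch only records the outcome.

```python
import torch

R = 3                        # hypothetical number of draws
t_n = torch.tensor([2, 2])   # two individuals, two rows each
N_n, N_t = len(t_n), int(t_n.sum())


def simulated_prob(src, index):
    """Multiply row probabilities into one slot per individual."""
    out = torch.ones(N_n, 1, R, dtype=src.dtype)
    return out.scatter_reduce(0, index, src, reduce="prod")


probs = torch.linspace(0.1, 0.9, N_t * R, dtype=torch.double).reshape(N_t, 1, R)

# One index row per individual: shape (N_n, 1, R), fewer rows than src
short_index = torch.arange(N_n).reshape(N_n, 1, 1).repeat(1, 1, R)
# One index row per observation: shape (N_t, 1, R), same rows as src
full_index = torch.arange(N_n).repeat_interleave(t_n).reshape(N_t, 1, 1).repeat(1, 1, R)

short_out = simulated_prob(probs, short_index)  # forward runs, reads only rows 0..N_n-1
full_out = simulated_prob(probs, full_index)    # product over each individual's rows

# Backward with the per-observation index
src = probs.clone().requires_grad_(True)
simulated_prob(src, full_index).log().sum().backward()

# Backward with the short index: record the outcome instead of assuming it
src_short = probs.clone().requires_grad_(True)
try:
    simulated_prob(src_short, short_index).log().sum().backward()
    short_backward_error = None
except RuntimeError as err:
    short_backward_error = str(err)
print("backward with short index:", short_backward_error)
```

If this is the mechanism, comparing the shapes of the index and source tensors right before each scatter or gather call would show it, and building the index with repeat_interleave over t_n (as in full_index) is one way to get one index row per observation.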


Finally, here is the code that replicates the described error.

Sample data:

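import argparse

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy.stats import norm, qmc
from torch.utils.data import DataLoader, Dataset, default_collate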
# args to be passed to the model
parser = argparse.ArgumentParser(description='')
parser.add_argument('--R', type=int, default=7, help='Number of draws (default: 7)')
args = parser.parse_args("") 
args.J = 3 # number of alternatives
args.batch_size = 2 # length of the batch
args.K = 3 # total number of parameters (K_r + K_f)
args.K_r = 1 # rand par
args.K_f = 2 # fixed par
X =  pd.DataFrame.from_dict({'x1_1': {0: -0.176, 1: 1.6458, 2: -0.13, 3: 1.96, 4: -1.70, 5: 1.45, 6: 0.06, 7: -1.21, 8: -0.30, 9: 0.07}, 'x1_2': {0: -2.420, 1: -1.0828, 2: 2.73, 3: 1.597, 4: 0.088, 5: 1.220, 6: -0.44, 7: -0.692, 8: 0.037, 9: 0.465}, 'x1_3': {0: -1.5483, 1: 0.8457, 2: -0.21250, 3: 0.52923, 4: -2.593, 5: -0.6188, 6: 1.69, 7: -1.027, 8: 0.63, 9: -0.771}, 'x2_1': {0: 0.379724, 1: -2.2364391598508835, 2: 0.6205947900678905, 3: 0.6623865847688559, 4: 1.562036259999875, 5: -0.13081282910947759, 6: 0.03914373833251773, 7: -0.995761652421108, 8: 1.0649494418154162, 9: 1.3744782478849122}, 'x2_2': {0: -0.5052556836786106, 1: 1.1464291788297152, 2: -0.5662380273138174, 3: 0.6875729143723538, 4: 0.04653136473130827, 5: -0.012885303852347407, 6: 1.5893672346098884, 7: 0.5464286050059511, 8: -0.10430829457707284, 9: -0.5441755265313813}, 'x2_3': {0: -0.9762973303149007, 1: -0.983731467806563, 2: 1.465827578266328, 3: 0.5325950414202745, 4: -1.4452121324204903, 5: 0.8148816373643869, 6: 0.470791989780882, 7: -0.17951636294180473, 8: 0.7351814781280054, 9: -0.28776723200679066}, 'x3_1': {0: 0.12751822396637064, 1: -0.21926633684030983, 2: 0.15758799357206943, 3: 0.5885412224632464, 4: 0.11916562911189271, 5: -1.6436210334529249, 6: -0.12444368631987467, 7: 1.4618564171802453, 8: 0.6847234328916137, 9: -0.23177118858569187}, 'x3_2': {0: -0.6452955690715819, 1: 1.052094761527654, 2: 0.20190339195326157, 3: 0.6839430295237913, 4: -0.2607691613858866, 5: 0.3315513026670213, 6: 0.015901139336566113, 7: 0.15243420084881903, 8: -0.7604225072161022, 9: -0.4387652927008854}, 'x3_3': {0: -1.067058994377549, 1: 0.8026914180717286, 2: -1.9868531745912268, 3: -0.5057770735303253, 4: -1.6589569342151713, 5: 0.358172252880764, 6: 1.9238983803281329, 7: 2.2518318810978246, 8: -1.2781475121874357, 9: -0.7103081175166167}})
Y = pd.DataFrame.from_dict({'CHOICE': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 2.0, 6: 1.0, 7: 1.0, 8: 2.0, 9: 2.0}})
Z = pd.DataFrame.from_dict({'z1': {0: 2.41967, 1: 2.41, 2: 2.822, 3: 2.82, 4: 2.07, 5: 2.073, 6: 2.04, 7: 2.04, 8: 2.40, 9: 2.40}, 'z2': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 1.0, 6: 1.0, 7: 1.0, 8: 0.0, 9: 0.0}, 'z3': {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 2.0, 5: 2.0, 6: 2.0, 7: 2.0, 8: 3.0, 9: 3.0}})
id = pd.DataFrame.from_dict({'id_choice': {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0, 5: 6.0, 6: 7.0, 7: 8.0, 8: 9.0, 9: 10.0}, 'id_ind': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 3.0, 6: 4.0, 7: 4.0, 8: 5.0, 9: 5.0}} )
# Create a dataframe with all the data 
data = pd.concat([id, X, Z, Y], axis=1)

Model + Dataset classes


class ChoiceDataset(Dataset):
    '''
    Dataset for choice data

    Args:
        data (pandas dataframe): dataframe with all the data

    Returns:
        dictionary with the data for each individual

    '''
    def __init__(self, data,  args , id_variable:str = "id_ind" ):
        if id_variable not in data.columns:
            raise ValueError(f"Variable {id_variable} not in dataframe")
        self.data = data
        self.args = args
        # select cluster variable
        self.cluster_ids = self.data[id_variable].unique()
        self.Y = torch.LongTensor(self.data['CHOICE'].values -1).reshape(len(self.data['CHOICE'].index),1)
        self.id = torch.LongTensor(self.data[id_variable].values).reshape(len(self.data[id_variable].index),1)
        self.N_n = torch.unique(self.id).shape[0]
        _ , self.t_n = self.id.unique(return_counts=True)
        self.N_t = self.t_n.sum(axis=0).item() 
        self.X_wide = data.filter(regex='^x') 

        if data.filter(regex='^ASC').shape[1] > 0:
            #Select variables that start with "ASC"
            self.ASC_wide = data.filter(regex='^ASC')
            # Concatenate X and ASC
            self.X_wide_plus_ASC = pd.concat([self.X_wide, self.ASC_wide], axis=1)        
            self.K = int(self.X_wide_plus_ASC.shape[1] / args.J)
        else:
            self.X_wide_plus_ASC = self.X_wide
            self.K = int(self.X_wide_plus_ASC.shape[1] / args.J)
        self.X = torch.DoubleTensor(self.X_wide_plus_ASC.values)
        self.Z = torch.DoubleTensor(self.data.filter(regex='^z').values)

    def __len__(self):
        return self.N_n # number of individuals
    
    def __getitem__(self, idx):
        # select the index of the individual
        self.index = torch.where(self.id == idx+1)[0]
        self.len_batch =  self.index.shape[0] 
        Y_batch = self.Y[self.index]
        Z_batch = self.Z[self.index]
        id_batch = self.id[self.index]
        # Number of choice sets (rows) for this individual
        t_n_batch = self.t_n[idx]
        # Select regressors for the individual
        X_batch = self.X[self.index]
        # reshape X_batch to have the right dimensions
        X_batch = X_batch.reshape(self.len_batch, self.K, self.args.J)
        return {'X': X_batch, 'Z': Z_batch, 'Y': Y_batch, 'id': id_batch, 't_n': t_n_batch}  


def cust_collate(batch_dict):
    '''
    This function is used to concatenate the data for each individual.
    It relies on the function `default_collate()` to concatenate the
    batches of each block of individuals. Later it resizes the tensors
    to concatenate the data for each individual using axis 0.
    
    Parameters
    ----------
    batch_dict : dict
        Dictionary containing the data for each block of individuals.
        Keys are 'X', 'Y', 'Z' and 'id'
    Returns
    -------
    collate_batche : dict
        Dictionary containing the data for each individual.
        Keys are 'X', 'Y', 'Z' and 'id' (again)
    '''

    def resize_batch_from_4D_to_3D_tensor(x:torch.Tensor):
        """
            This function suppresses the extra dimension created by 
            `default_collate()` and concatenates the data for each 
            individual along axis 0, inferring (`-1`) the size of dimension zero.
        Parameters
        ----------
        x : torch.Tensor (4D)
        Returns
        -------
        torch.Tensor (3D)                
        """
        return  x.view(-1, x.size(2), x.size(3))

    def resize_batch_from_3D_to_2D_tensor(y:torch.Tensor):
        """
            This function suppresses the extra dimension created by 
            `default_collate()` and concatenates the data for each 
            individual along axis 0, inferring (`-1`) the size of dimension zero.
        Parameters
        ----------
        y : torch.Tensor (3D)
        Returns
        -------
        torch.Tensor (2D)                
        """
        return  y.view(-1,  y.size(2))    
    collate_batche = default_collate(batch_dict)
    # Resize the tensors to concatenate the data for each individual using axis 0.
    collate_batche['X'] = resize_batch_from_4D_to_3D_tensor(collate_batche['X'])
    collate_batche['Y'] = resize_batch_from_3D_to_2D_tensor(collate_batche['Y'])
    collate_batche['Z'] = resize_batch_from_3D_to_2D_tensor(collate_batche['Z'])
    collate_batche['id'] = resize_batch_from_3D_to_2D_tensor(collate_batche['id'])
    # Number of individuals in the batch
    collate_batche['N_n_batch'] = torch.unique(collate_batche['id']).shape[0]
    # Total Number of choices sets in the batch     
    collate_batche['N_t_batch'] = collate_batche['Y'].shape[0]
    
    return collate_batche



# model  
class LL_MODEL(nn.Module):
    def __init__(self,args):
        super(LL_MODEL, self).__init__()
        self.args = args 
        self.sum_param = self.args.K_r + self.args.K_f 
        assert self.args.K == self.sum_param , "Total number of parameters K is not equal to the sum of the number of random, fixed and taste parameters"

        self.rand_param_list = nn.ParameterList([])
        for i in range(2 * self.args.K_r):           
            beta_rand = nn.Parameter(torch.zeros(1,dtype=torch.double, requires_grad=True))
            self.rand_param_list.append(beta_rand)

        if self.args.K_f > 0:
            self.fix_param_list = nn.ParameterList([])
            for i in range(self.args.K_f): 
                beta_fix = nn.Parameter(torch.zeros(1,dtype=torch.double, requires_grad=True))
                self.fix_param_list.append(beta_fix)

    def forward(self, data):
        '''
        This function defines the forward pass of the model.
        It receives as input the data and the draws from the QMC sequence.
        It returns the log-likelihood of the model.
        ----------------
        Parameters:
            d: (dataset) dictionary with keys: X, Y, id, t_n, Z (if needed))
                X:   (tensor) dimension: (N_t x K x J) [attributes levels]
                Y:   (tensor) dimension: (N_t, 1)    [index of chosen alternative]
                id:  (tensor) dimension: (N_t, 1)    [individual id]
                t_n: (tensor) dimension: (N_n, 1)    [number of choice sets per individual]
                Z:   (tensor) dimension: (N_t, K_t) [individual characteristics]
                Draws: (tensor) dimension: (N_n, J, R)
                    N_n: number of individuals
                    J: number of alternatives
                    R: number of draws
        ----------------                
        Output:
            simulated_log_likelihood: (tensor) dimension: (1,1)
        '''
        self.N_t = data['N_t_batch']
        self.N_n = data['N_n_batch']
        self.K = self.args.K
        #self.Draws = data['Draws']
        self.Draws = Create_Draws(self.args.K_r, self.N_n, self.args.R, data['t_n'], self.args.J)
        print('self.Draws.shape',self.Draws.shape)

        self.X = data['X']
        self.Y = data['Y']
        self.t_n = data['t_n'].reshape(self.N_n,1)
        self.Z = data['Z'] if data['Z'] is not None else None
        self.id = torch.from_numpy(np.arange(self.N_n)).reshape(self.N_n,1) 

        rand_par = [self.rand_param_list[i] for i in range(2 * self.args.K_r)] 
        if self.args.K_f > 0:
            fix_par  = [self.fix_param_list[i] for i in range(self.args.K_f)]
            self.params = torch.cat(rand_par  + fix_par).reshape(self.args.K_f + 2 * self.args.K_r,1)
        else:
            self.params = torch.cat(rand_par).reshape(2 * self.args.K_r,1)
        
        self.beta_means = self.params[0:2*self.args.K_r:2 ,0].reshape(self.args.K_r,1,1,1)
        self.beta_stds  = self.params[1:2*self.args.K_r:2 ,0].reshape(self.args.K_r,1,1,1)
        self.beta_R = torch.empty(
            self.args.K_r,
            self.N_t, 
            self.args.J, 
            self.args.R)
        for i in range(self.args.K_r):
            self.beta_R[i,:,:,:] = self.beta_means[i,:,:,:] + self.beta_stds[i,:,:,:] * self.Draws[i,:,:,:]
            print('self.beta_R[i,:,:,:].shape',self.beta_R[i,:,:,:].shape)

        if self.args.K_f > 0:
            self.beta_F = self.params[2*self.args.K_r:2*self.args.K_r + self.args.K_f,0].reshape(self.args.K_f,1)
            self.beta_F = self.beta_F.repeat(1, self.N_n * self.args.J * self.args.R).reshape(
                self.args.K_f, 
                self.N_n, 
                self.args.J,
                self.args.R)
            self.beta_F = self.beta_F.repeat_interleave(self.t_n.reshape(len(self.t_n)), dim = 1)

        if self.args.K_f > 0:
            self.all_beta = torch.cat((self.beta_R, self.beta_F), 0)
        else:
            self.all_beta = self.beta_R 
        self.all_X = self.X.transpose(0,1).reshape(
            self.args.K,
            self.N_t,
            self.args.J)
        self.all_X = self.all_X[:,:,:,None].repeat(1,1,1,self.args.R)
        self.V_for_R_draws = torch.einsum(
            'abcd,abcd->bcd', 
            self.all_X.double(), 
            self.all_beta.double()
            )
        self.V_for_R_draws_exp = torch.exp(self.V_for_R_draws)  
        self.sum_of_exp = self.V_for_R_draws_exp.sum(dim=1)
        self.prob_per_draw = self.V_for_R_draws_exp/self.sum_of_exp[:,None,:]
        self.Y_expand  = self.Y[:,:,None].repeat(1,1,self.args.R)
        self.prob_chosen_per_draw = self.prob_per_draw.gather(1,self.Y_expand)
        self.id_expand = self.id[:,:,None].repeat(1,1,self.args.R).to(torch.int64)
        self.prod_seq_choices_ind = torch.ones(
            self.N_n,
            1,
            self.args.R,  
            dtype=self.prob_chosen_per_draw.dtype)
        self.prod_seq_choices_ind = self.prod_seq_choices_ind.scatter_reduce(
            0,    # the dimension to scatter on (0 for rows, 1 for columns)
            self.id_expand,   # the index to scatter
            self.prob_chosen_per_draw, # the value to scatter
            reduce='prod' # the reduction operation
        )

        self.proba_simulated = self.prod_seq_choices_ind.mean(dim=2)
        LL = torch.log(self.proba_simulated).sum()

        return LL

def Create_Draws(dimension,N,R,t_n,J):
    '''
    Create Draws for the random parameters
    input:
        dimension: number of random parameters
        N: number of individuals
        R: number of draws
        t_n: number of choices per individual
        J: number of alternatives
    output:
        normal_draws: tensor of size (dimension, N.repeat(t_n), J, R)
    '''
    np.random.seed(123456789)
    Halton_sampler = qmc.Halton(d=dimension, scramble=True)
    normal_draws = norm.ppf(Halton_sampler.random(n=N * R * J))
    normal_draws = torch.tensor(normal_draws).reshape(dimension, N, J, R) 
    normal_draws = normal_draws.repeat_interleave(t_n.reshape(len(t_n)),1) 
    print("normal_draws shape",normal_draws.shape)
    return normal_draws
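To check the batching independently of the model, the collate logic above can be run on synthetic items. This is a minimal sketch with made-up values: fake_individual is a hypothetical stand-in for ChoiceDataset.__getitem__ that only mimics its output shapes, and J, K, T are taken from the sample data (3 alternatives, 3 parameters, 2 rows per individual).

```python
import torch
from torch.utils.data import default_collate

J, K, T = 3, 3, 2  # alternatives, parameters, rows per individual


def fake_individual(i):
    """Hypothetical stand-in for one dataset item (shapes only, values made up)."""
    return {
        "X": torch.full((T, K, J), float(i), dtype=torch.double),
        "Y": torch.zeros(T, 1, dtype=torch.long),
        "id": torch.full((T, 1), i + 1, dtype=torch.long),
        "t_n": torch.tensor(T),
    }


# default_collate stacks the items along a new leading dimension
batch = default_collate([fake_individual(i) for i in range(2)])
print(batch["X"].shape)   # 4D: (individuals, T, K, J)

# Flatten individuals and rows into a single observation axis
X = batch["X"].view(-1, K, J)
ids = batch["id"].view(-1, 1)
print(X.shape, ids.flatten().tolist(), batch["t_n"].tolist())
```

After flattening, the leading dimension counts observations (rows), while t_n still has one entry per individual, so any tensor built from the number of individuals needs to be expanded before it is combined with row-level tensors.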

Training Loop

This is where the error is raised, at the call to loss.backward().

#torch.autograd.set_detect_anomaly(True)  # enable anomaly detection
# Defining my dataset
DataSet_Choice= ChoiceDataset(data ,args, id_variable = "id_ind" )
# Defining my dataloader
DataLoader_obj = DataLoader(DataSet_Choice, collate_fn=cust_collate, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
# Defining my model 
model = LL_MODEL(args)
# Training loop
for idx, data_batch  in enumerate(DataLoader_obj):
    # Compute the loss
    loss = model(data_batch)
    # Compute the gradient
    loss.backward() #<- this fails even when the loss is computed fine.

# RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0 

Full error message


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
c:\_scratch_code\_problem_isolation.py in line 299
    297 loss = model(data_batch)
    298 # Compute the gradient
--> 299 loss.backward() #<- this fails even when the loss is computed fine.

File c:\lib\site-packages\torch\_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    387 if has_torch_function_unary(self):
    388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
   (...)
    394         create_graph=create_graph,
    395         inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File c:\lib\site-packages\torch\autograd\__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
...
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0

Thank you.

Double post from here.