`backward()` fails even though the loss function is computed correctly

I have narrowed my problem down to the following snippet of my training loop (all required functions and dependencies are at the end). I still don't understand why the loss function is computed correctly, yet the `backward()` method fails to produce the gradients.

Data + hyperparams

import torch
from torch.utils.data import Dataset, default_collate, DataLoader
import argparse
import numpy as np
import pandas as pd
import torch.nn as nn
from scipy.stats import qmc, norm
import traceback
X =  pd.DataFrame.from_dict({'x1_1': {0: -0.176, 1: 1.6458, 2: -0.13, 3: 1.96, 4: -1.70, 5: 1.45, 6: 0.06, 7: -1.21, 8: -0.30, 9: 0.07}, 'x1_2': {0: -2.420, 1: -1.0828, 2: 2.73, 3: 1.597, 4: 0.088, 5: 1.220, 6: -0.44, 7: -0.692, 8: 0.037, 9: 0.465}, 'x1_3': {0: -1.5483, 1: 0.8457, 2: -0.21250, 3: 0.52923, 4: -2.593, 5: -0.6188, 6: 1.69, 7: -1.027, 8: 0.63, 9: -0.771}, 'x2_1': {0: 0.379724, 1: -2.2364391598508835, 2: 0.6205947900678905, 3: 0.6623865847688559, 4: 1.562036259999875, 5: -0.13081282910947759, 6: 0.03914373833251773, 7: -0.995761652421108, 8: 1.0649494418154162, 9: 1.3744782478849122}, 'x2_2': {0: -0.5052556836786106, 1: 1.1464291788297152, 2: -0.5662380273138174, 3: 0.6875729143723538, 4: 0.04653136473130827, 5: -0.012885303852347407, 6: 1.5893672346098884, 7: 0.5464286050059511, 8: -0.10430829457707284, 9: -0.5441755265313813}, 'x2_3': {0: -0.9762973303149007, 1: -0.983731467806563, 2: 1.465827578266328, 3: 0.5325950414202745, 4: -1.4452121324204903, 5: 0.8148816373643869, 6: 0.470791989780882, 7: -0.17951636294180473, 8: 0.7351814781280054, 9: -0.28776723200679066}, 'x3_1': {0: 0.12751822396637064, 1: -0.21926633684030983, 2: 0.15758799357206943, 3: 0.5885412224632464, 4: 0.11916562911189271, 5: -1.6436210334529249, 6: -0.12444368631987467, 7: 1.4618564171802453, 8: 0.6847234328916137, 9: -0.23177118858569187}, 'x3_2': {0: -0.6452955690715819, 1: 1.052094761527654, 2: 0.20190339195326157, 3: 0.6839430295237913, 4: -0.2607691613858866, 5: 0.3315513026670213, 6: 0.015901139336566113, 7: 0.15243420084881903, 8: -0.7604225072161022, 9: -0.4387652927008854}, 'x3_3': {0: -1.067058994377549, 1: 0.8026914180717286, 2: -1.9868531745912268, 3: -0.5057770735303253, 4: -1.6589569342151713, 5: 0.358172252880764, 6: 1.9238983803281329, 7: 2.2518318810978246, 8: -1.2781475121874357, 9: -0.7103081175166167}})
Y = pd.DataFrame.from_dict({'CHOICE': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 2.0, 6: 1.0, 7: 1.0, 8: 2.0, 9: 2.0}})
Z = pd.DataFrame.from_dict({'z1': {0: 2.41967, 1: 2.41, 2: 2.822, 3: 2.82, 4: 2.07, 5: 2.073, 6: 2.04, 7: 2.04, 8: 2.40, 9: 2.40}, 'z2': {0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 1.0, 5: 1.0, 6: 1.0, 7: 1.0, 8: 0.0, 9: 0.0}, 'z3': {0: 1.0, 1: 1.0, 2: 1.0, 3: 1.0, 4: 2.0, 5: 2.0, 6: 2.0, 7: 2.0, 8: 3.0, 9: 3.0}})
id = pd.DataFrame.from_dict({'id_choice': {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0, 4: 5.0, 5: 6.0, 6: 7.0, 7: 8.0, 8: 9.0, 9: 10.0}, 'id_ind': {0: 1.0, 1: 1.0, 2: 2.0, 3: 2.0, 4: 3.0, 5: 3.0, 6: 4.0, 7: 4.0, 8: 5.0, 9: 5.0}} )
data = pd.concat([id, X, Z, Y], axis=1)
parser = argparse.ArgumentParser(description='')
parser.add_argument('--R', type=int, default=5, help='Number of draws (default: 5)')
args = parser.parse_args("") 
args.J = 3 # number of alternatives
args.batch_size = 2 # number of individuals per batch
args.K = 3 # total number of parameters (K_r + K_f)
args.K_r = 1 # number of random parameters
args.K_f = 2 # number of fixed parameters

Training loop where `backward()` fails

Here, while looping over the individuals, I create a matrix of Halton draws for the numerical integration of the random coefficients (`draws`) and add it to the dictionary of data for the batch with `new_data_batch['Draws'] = draws`. Computing the loss function works without any problem. However, when I call the `.backward()` method I get a conformability (size-mismatch) error that I cannot track down. What strikes me as unexpected is that if the tensors had incompatible dimensions, I would expect an error already when computing the loss function; but that is not the case here, since the error only appears in `.backward()`. Any ideas why this might be happening?

torch.autograd.set_detect_anomaly(True)  # enable anomaly detection
# Defining my dataset
DataSet_Choice= ChoiceDataset(data ,args, id_variable = "id_ind" )
# Defining my dataloader
DataLoader_Choice = DataLoader(DataSet_Choice, collate_fn=cust_collate, batch_size=args.batch_size, shuffle=False, num_workers=0, drop_last=False)
# Defining my model 
model = LL_MODEL(args)
# Training loop
for idx, data_batch  in enumerate(DataLoader_Choice):
    print("idx",idx)
    # Create draws for numerical integration of random parameters
    draws = Create_Draws(args.K_r, data_batch['N_n_batch'], args.R ,data_batch['t_n'], args.J)
    # Create a new dictionary for the data
    new_data_batch = data_batch.copy()  # create a new dictionary
    # Add draws to the new dictionary
    new_data_batch['Draws'] = draws
    # Compute the loss
    loss = model.forward(new_data_batch) 
    print("loss",loss)
    # Compute the gradient
    loss.backward() #<- this fails even when the loss is computed correctly.
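
As a sanity check, here is a minimal sketch of a shape print-out that can be dropped into the loop body just before `loss.backward()` (the keys are the ones produced by `cust_collate` below, plus the `'Draws'` entry added above):

# Minimal sketch: print the shapes that enter the loss right before backward()
for key in ('X', 'Y', 'Z', 'id', 'Draws'):
    print(key, tuple(new_data_batch[key].shape))
print('t_n', new_data_batch['t_n'])
print('N_n_batch', new_data_batch['N_n_batch'], 'N_t_batch', new_data_batch['N_t_batch'])
print('params', [tuple(p.shape) for p in model.parameters()])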
 

Full error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
c:\Users\scratch_code\_problem_isolation.py in line 344
    342 print("loss",loss)
    343 # Compute the gradient
--> 344 loss.backward() #<- this fails even when the loss is computed fine.

File c:\Users\lib\site-packages\torch\_tensor.py:396, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    387 if has_torch_function_unary(self):
    388     return handle_torch_function(
    389         Tensor.backward,
    390         (self,),
   (...)
    394         create_graph=create_graph,
    395         inputs=inputs)
--> 396 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File c:\Users\lib\site-packages\torch\autograd\__init__.py:173, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    168     retain_graph = create_graph
    170 # The reason we repeat same the comment below is that
...
--> 173 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    174     tensors, grad_tensors_, retain_graph, create_graph, inputs,
    175     allow_unreachable=True, accumulate_grad=True)

RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 0

Required functions

Data manipulation functions


class ChoiceDataset(Dataset):
    '''
    Dataset for choice data

    Args:
        data (pandas dataframe): dataframe with all the data

    Returns:
        dictionary with the data for each individual

    '''
    def __init__(self, data,  args , id_variable:str = "id_ind" ):

        if id_variable not in data.columns:
            raise ValueError(f"Variable {id_variable} not in dataframe")
        
        self.data = data
        self.args = args
        # select cluster variable
        self.cluster_ids = self.data[id_variable].unique()
        self.Y = torch.LongTensor(self.data['CHOICE'].values -1).reshape(len(self.data['CHOICE'].index),1)
        self.id = torch.LongTensor(self.data[id_variable].values).reshape(len(self.data[id_variable].index),1)
        # number of individuals (N_n)
        self.N_n = torch.unique(self.id).shape[0]
        # number of choices made per individual (t_n)
        _ , self.t_n = self.id.unique(return_counts=True)
        #total number of observations (N_t = total number of choices)
        self.N_t = self.t_n.sum(axis=0).item() 
        # Select regressors: variables that start with "x"
        self.X_wide = data.filter(regex='^x') 

        # Check if there are any ASC variables
        if data.filter(regex='^ASC').shape[1] > 0:
            #Select variables that start with "ASC"
            self.ASC_wide = data.filter(regex='^ASC')
            # Concatenate X and ASC
            self.X_wide_plus_ASC = pd.concat([self.X_wide, self.ASC_wide], axis=1)        
            self.K = int(self.X_wide_plus_ASC.shape[1] / args.J)
        else:
            self.X_wide_plus_ASC = self.X_wide
            self.K = int(self.X_wide_plus_ASC.shape[1] / args.J)
        
        # turn X_wide into a tensor (it is reshaped to (len_batch, K, J) in __getitem__)
        self.X = torch.DoubleTensor(self.X_wide_plus_ASC.values)
        # Select variables that start with "z"
        self.Z = torch.DoubleTensor(self.data.filter(regex='^z').values)


    def __len__(self):
        return self.N_n # number of individuals

    def __getitem__(self, idx):
        # select the index of the individual
        self.index = torch.where(self.id == idx+1)[0]
        self.len_batch =  self.index.shape[0] 
        # Select observations for the individual
        Y_batch = self.Y[self.index]
        Z_batch = self.Z[self.index]
        # Select id for the individual
        id_batch = self.id[self.index]
        #print('idx',idx)
        #print('id_batch:',id_batch)
        # Number of individuals in the batch

        # Number of choices per  individual
        t_n_batch = self.t_n[idx]

        # Select regressors for the individual
        X_batch = self.X[self.index]
        # reshape X_batch to have the right dimensions
        X_batch = X_batch.reshape(self.len_batch, self.K, self.args.J)
        return {'X': X_batch, 'Z': Z_batch, 'Y': Y_batch, 'id': id_batch, 't_n': t_n_batch}  


def cust_collate(batch_dict):
    '''
    This function is used to concatenate the data for each individual.
    It relies on the function `default_collate()` to concatenate the
    batches of each block of individuals. Later it resizes the tensors
    to concatenate the data for each individual using axis 0.
    
    Parameters
    ----------
    batch_dict : dict
        Dictionary containing the data for each block of individuals.
        Keys are 'X', 'Y', 'Z' and 'id'
    Returns
    -------
    collate_batche : dict
        Dictionary containing the data for each individual.
        Keys are 'X', 'Y', 'Z' and 'id' (again)
    '''

    def resize_batch_from_4D_to_3D_tensor(x:torch.Tensor):
        """
            This function suppresses the extra dimension created by
            `default_collate()` and concatenates the data for each
            individual along axis 0, inferring (`-1`) dimension zero.
        Parameters
        ----------
        x : torch.Tensor (4D)
        Returns
        -------
        torch.Tensor (3D)                
        """
        return  x.view(-1, x.size(2), x.size(3))

    def resize_batch_from_3D_to_2D_tensor(y:torch.Tensor):
        """
            This function suppresses the extra dimension created by
            `default_collate()` and concatenates the data for each
            individual along axis 0, inferring (`-1`) dimension zero.
        Parameters
        ----------
        y : torch.Tensor (3D)
        Returns
        -------
        torch.Tensor (2D)                
        """
        return  y.view(-1,  y.size(2))    
    collate_batche = default_collate(batch_dict)
    # Resize the tensors to concatenate the data for each individual using axis 0.
    collate_batche['X'] = resize_batch_from_4D_to_3D_tensor(collate_batche['X'])
    collate_batche['Y'] = resize_batch_from_3D_to_2D_tensor(collate_batche['Y'])
    collate_batche['Z'] = resize_batch_from_3D_to_2D_tensor(collate_batche['Z'])
    collate_batche['id'] = resize_batch_from_3D_to_2D_tensor(collate_batche['id'])
    # Number of individuals in the batch
    collate_batche['N_n_batch'] = torch.unique(collate_batche['id']).shape[0]
    # Total number of choice sets in the batch
    collate_batche['N_t_batch'] = collate_batche['Y'].shape[0]
    
    return collate_batche
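
For context, with the example data above (two choice situations per individual, K = 3 attributes, J = 3 alternatives) and `batch_size = 2`, the collated batch should come out with these shapes (a minimal sketch, assuming the dataset and collate function defined here):

# Minimal sketch: expected shapes of the first collated batch for the example data
ds = ChoiceDataset(data, args, id_variable="id_ind")
dl = DataLoader(ds, collate_fn=cust_collate, batch_size=2, shuffle=False)
batch = next(iter(dl))
print(batch['X'].shape)   # torch.Size([4, 3, 3]) -> (N_t_batch, K, J)
print(batch['Y'].shape)   # torch.Size([4, 1])
print(batch['Z'].shape)   # torch.Size([4, 3])
print(batch['id'].shape)  # torch.Size([4, 1])
print(batch['t_n'])       # tensor([2, 2])
print(batch['N_n_batch'], batch['N_t_batch'])  # 2 4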


Model definition

class LL_MODEL(nn.Module):
    def __init__(self,args):
        super(LL_MODEL, self).__init__()
        '''
        '''
        self.args = args 
        self.sum_param = self.args.K_r + self.args.K_f 
        assert self.args.K == self.sum_param, "Total number of parameters K must equal the number of random plus fixed parameters"

        self.rand_param_list = nn.ParameterList([])
        for i in range(2 * self.args.K_r):           
            beta_rand = nn.Parameter(torch.zeros(1,dtype=torch.double, requires_grad=True))
            self.rand_param_list.append(beta_rand)

        if self.args.K_f > 0:
            self.fix_param_list = nn.ParameterList([])
            for i in range(self.args.K_f): 
                beta_fix = nn.Parameter(torch.zeros(1,dtype=torch.double, requires_grad=True))
                self.fix_param_list.append(beta_fix)


    def forward(self, data):
        '''
        This function defines the forward pass of the model.
        It receives as input the data and the draws from the QMC sequence.
        It returns the log-likelihood of the model.
        ----------------
        Parameters:
            data: (dict) dictionary with keys: X, Y, id, t_n, Z (if needed)
                X:   (tensor) dimension: (N_t, K, J)  [attribute levels]
                Y:   (tensor) dimension: (N_t, 1)    [chosen alternative]
                id:  (tensor) dimension: (N_t, 1)    [individual id]
                t_n: (tensor) dimension: (N_n, 1)    [number of choice sets per individual]
                Z:   (tensor) dimension: (N_t, K_t) [individual characteristics]
                Draws: (tensor) dimension: (N_n, J, R)
                    N_n: number of individuals
                    J: number of alternatives
                    R: number of draws
        ----------------                
        Output:
            simulated_log_likelihood: (tensor) dimension: (1,1)
        '''
        self.N_t = data['N_t_batch']
        self.N_n = data['N_n_batch']
        self.K = self.args.K
        Draws = data['Draws']
        self.X = data['X']
        self.Y = data['Y']
        self.t_n = data['t_n'].reshape(self.N_n,1)
        self.Z = data['Z'] if data['Z'] is not None else None
        self.id = torch.from_numpy(np.arange(self.N_n)).reshape(self.N_n,1) 


        rand_par = [self.rand_param_list[i] for i in range(2 * self.args.K_r)] 
        if self.args.K_f > 0:
            fix_par  = [self.fix_param_list[i] for i in range(self.args.K_f)]
            self.params = torch.cat(rand_par  + fix_par).reshape(self.args.K_f + 2 * self.args.K_r,1)
        else:
            self.params = torch.cat(rand_par).reshape(2 * self.args.K_r,1)
        
        self.beta_means = self.params[0:2*self.args.K_r:2 ,0].reshape(self.args.K_r,1,1,1)
        self.beta_stds  = self.params[1:2*self.args.K_r:2 ,0].reshape(self.args.K_r,1,1,1)
        self.beta_R = torch.empty(
            self.args.K_r,
            self.N_t, 
            self.args.J, 
            self.args.R)
        
        for i in range(self.args.K_r):
            self.beta_R[i,:,:,:] = self.beta_means[i,:,:,:] + self.beta_stds[i,:,:,:] * Draws[i,:,:,:]
        if self.args.K_f > 0:
            self.beta_F = self.params[2*self.args.K_r:2*self.args.K_r + self.args.K_f,0].reshape(self.args.K_f,1)
            self.beta_F = self.beta_F.repeat(1, self.N_n * self.args.J * self.args.R).reshape(
                self.args.K_f, 
                self.N_n, 
                self.args.J,
                self.args.R)
            self.beta_F = self.beta_F.repeat_interleave(self.t_n.reshape(len(self.t_n)), dim = 1)

        if self.args.K_f > 0:
            self.all_beta = torch.cat((self.beta_R, self.beta_F), 0)
        else:
            self.all_beta = self.beta_R 
        self.all_X = self.X.transpose(0,1).reshape(
            self.args.K,
            self.N_t,
            self.args.J)
        self.all_X = self.all_X[:,:,:,None].repeat(1,1,1,self.args.R)
        self.V_for_R_draws = torch.einsum(
            'abcd,abcd->bcd', 
            self.all_X.double(), 
            self.all_beta.double()
            )
        self.V_for_R_draws_exp = torch.exp(self.V_for_R_draws)  
        self.sum_of_exp = self.V_for_R_draws_exp.sum(dim=1)
        self.prob_per_draw = self.V_for_R_draws_exp/self.sum_of_exp[:,None,:]
        self.Y_expand  = self.Y[:,:,None].repeat(1,1,self.args.R)
        self.prob_chosen_per_draw = self.prob_per_draw.gather(1,self.Y_expand)
        self.id_expand = self.id[:,:,None].repeat(1,1,self.args.R).to(torch.int64)
        self.prod_seq_choices_ind = torch.ones(
            self.N_n,
            1,
            self.args.R,  
            dtype=self.prob_chosen_per_draw.dtype)
        self.prod_seq_choices_ind = self.prod_seq_choices_ind.scatter_reduce(
            0,    # the dimension to scatter on (0 for rows, 1 for columns)
            self.id_expand,   # the index to scatter
            self.prob_chosen_per_draw, # the value to scatter
            reduce='prod' # the reduction operation
        )

        self.proba_simulated = self.prod_seq_choices_ind.mean(dim=2)
        LL = torch.log(self.proba_simulated).sum()

        return LL
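
For reference, `forward()` is meant to return the simulated log-likelihood of a mixed logit model (my notation, summarising the code above):

$$\mathrm{LL} = \sum_{n=1}^{N_n} \log\!\left(\frac{1}{R}\sum_{r=1}^{R}\prod_{t=1}^{t_n} P_{ntr}\right),\qquad P_{ntr} = \frac{\exp\!\big(V_{n t j^{*}_{nt} r}\big)}{\sum_{j=1}^{J}\exp\!\big(V_{ntjr}\big)},\qquad V_{ntjr} = x_{ntj}^{\top}\beta_{nr},$$

where $j^{*}_{nt}$ is the chosen alternative and $\beta_{nr}$ stacks the simulated random coefficients (mean + standard deviation × Halton draw) and the fixed coefficients.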

def Create_Draws(dimension,N,R,t_n,J):
    '''
    Create Draws for the random parameters
    input:
        dimension: number of random parameters
        N: number of individuals
        R: number of draws
        t_n: number of choices per individual
        J: number of alternatives
    output:
        normal_draws: tensor of size (dimension, sum(t_n), J, R) (each individual's draws repeated t_n times along dim 1)

    '''
    np.random.seed(123456789)
    Halton_sampler = qmc.Halton(d=dimension, scramble=True)
    normal_draws = norm.ppf(Halton_sampler.random(n=N * R * J))
    normal_draws = torch.tensor(normal_draws).reshape(dimension, N, J, R) 
    normal_draws = normal_draws.repeat_interleave(t_n.reshape(len(t_n)),1) 
    print("normal_draws shape",normal_draws.shape)
    return normal_draws
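
As a quick sanity check (a minimal sketch using the example hyperparameters above), for the first batch of two individuals with two choice situations each the draws come out with shape (K_r, sum(t_n), J, R) = (1, 4, 3, 5):

# Minimal sketch: expected shape of the draws for the first batch of the example
t_n_batch = torch.tensor([2, 2])  # two individuals, two choice situations each
d = Create_Draws(args.K_r, 2, args.R, t_n_batch, args.J)
print(d.shape)  # torch.Size([1, 4, 3, 5])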