Tabular Data (DAE + MLP model): nan values while training

Hi guys and girls,

I’m new to PyTorch and more experienced with Keras, GBM, …, but I was curious about the performance and power of PyTorch, so I decided to dive in.

I picked up some shared code for a DAE + MLP from a Kaggle competition (Tab-Apr) and reapplied it (somewhat successfully) to April’s competition.

I went through the code in depth to understand it and see what I could tweak and learn about PyTorch.
I’ve noticed a curious occurrence of NaN values during training (after about 50 epochs), which puzzles me, and I’m hoping you can help me understand why it happens.
What I’ve tried (some of the items below might delay the NaN “explosion”, but none of them solves it):

  • Normalisation of the dataset (it was already normalised with L1, but just in case I tried L2)
  • drop_last=True in the DataLoader
  • Fixed mask probabilities in the noise maker
  • A low eps in the model

While debugging, I can see the following about the NaNs in train_fn:

  • the first time they appear is around the middle-to-end of the dataset (batch ±50 of 79) [all batches after that also return NaN]
  • once an epoch has any batch with NaN, every following epoch gets NaN in all batches
  • outputs = model(dae.feature(inputs)) [outputs gets the NaN values]
  • dae.feature(inputs) does not produce any NaN
  • inputs doesn’t have any NaN

I’m running out of ideas about what could be causing this. As I said, I’ve tried a couple of things that delay or accelerate the “explosion”, but I can’t understand what is causing it, and I would like to.

I will share a snapshot of the code I think is most relevant; feel free to ask for more detailed info or for me to share it by other means.
This is my first post here, so I’m eager to learn.

Thanks in advance

Simplified Main (for readability)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold

dae = TransformerAutoEncoder(...)  # constructor arguments omitted in the original post
skf = StratifiedKFold(n_splits=SETUP['nfolds'], random_state=SETUP['random_seed'], shuffle=True)

for fold, (train_idx, valid_idx) in enumerate(skf.split(X[:len_train], y[:len_train])):
    train_dataset = FeatureDataset(X[:len_train][train_idx], y[:len_train][train_idx])
    valid_dataset = FeatureDataset(X[:len_train][valid_idx], y[:len_train][valid_idx])
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # , drop_last=True)
    validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    # MLP model
    model = Model(num_features=num_features, num_targets=num_targets, hidden_size=hidden_size).to(device)
    # optimizer and scheduler are created here in the full code (omitted in this simplified version)
    loss_fn = nn.BCEWithLogitsLoss()
    loss_tr = nn.BCEWithLogitsLoss()
    for epoch in range(SETUP['epochs']):
        train_loss = train_fn(dae, model, optimizer, scheduler, loss_tr, trainloader, epoch, device)
        valid_loss, valid_preds = valid_fn(dae, model, loss_fn, validloader, device)

    # reload model with best result and predict Xtest
    del trainloader, validloader, train_dataset, valid_dataset
    predictions += inference_fn(dae, model, testloader, device)

Train (where NaNs are seen) / Eval / Predict custom functions

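def train_fn(dae, model, optimizer, scheduler, loss_fn, dataloader, epoch, device=SETUP['device']):
    model.train()
    final_loss = 0
    # swap-noise probability decays with the epoch number
    noise_maker = SwapNoiseMasker(SETUP['mlp_start_noise'] * (SETUP['mlp_noise_decay'] ** epoch))
    for i, data in enumerate(dataloader):
        inputs, targets = data['x'].to(device), data['y'].to(device)
        inputs, mask = noise_maker.apply(inputs)
        optimizer.zero_grad()
        # outputs contain NaN after ~50 epochs
        outputs = model(dae.feature(inputs))
        loss = loss_fn(outputs, targets)
        # update steps were not shown in the post; a standard loop is assumed here
        loss.backward()
        optimizer.step()
        scheduler.step()  # assumption: per-batch scheduler; move to the epoch loop if it is epoch-based
        final_loss += loss.item()
    final_loss /= len(dataloader)

    return final_loss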

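def valid_fn(dae, model, loss_fn, dataloader, device=SETUP['device']):
    model.eval()  # BatchNorm uses running stats, Dropout is disabled
    final_loss = 0
    valid_preds = []
    with torch.no_grad():
        for data in dataloader:
            inputs, targets = data['x'].to(device), data['y'].to(device)
            outputs = model(dae.feature(inputs))
            loss = loss_fn(outputs, targets)
            final_loss += loss.item()
            # collect probabilities; without this append, np.concatenate([]) raises ValueError
            valid_preds.append(outputs.sigmoid().cpu().numpy())
    final_loss /= len(dataloader)
    valid_preds = np.concatenate(valid_preds)
    model.train()  # switch back so the next training epoch runs in train mode
    return final_loss, valid_preds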

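def inference_fn(dae, model, dataloader, device=SETUP['device']):
    model.eval()
    preds = []
    for data in dataloader:
        inputs = data['x'].to(device)
        with torch.no_grad():
            outputs = model(dae.feature(inputs))
        # the model returns logits (BCEWithLogitsLoss), so apply sigmoid to get probabilities
        preds.append(outputs.sigmoid().cpu().numpy())

    preds = np.concatenate(preds).reshape(-1,)
    return preds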

MLP Model

class Model(nn.Module):
    def __init__(self, num_features=3000, num_targets=1, hidden_size=1000):
        super(Model, self).__init__()
        self.batch_norm1 = nn.BatchNorm1d(num_features, eps=1e-15)
        self.dropout1 = nn.Dropout(SETUP['mlp_dropout'])
        self.dense1 = nn.utils.weight_norm(nn.Linear(num_features, hidden_size))
        self.batch_norm2 = nn.BatchNorm1d(hidden_size, eps=1e-15)
        self.dropout2 = nn.Dropout(SETUP['mlp_dropout'])
        self.dense2 = nn.utils.weight_norm(nn.Linear(hidden_size, hidden_size))
        self.batch_norm3 = nn.BatchNorm1d(hidden_size, eps=1e-15)
        self.dropout3 = nn.Dropout(SETUP['mlp_dropout'])
        self.dense3 = nn.utils.weight_norm(nn.Linear(hidden_size, num_targets))
    def forward(self, x):
        x = self.batch_norm1(x)
        x = self.dropout1(x)
        x = F.relu(self.dense1(x))
        x = self.batch_norm2(x)
        x = self.dropout2(x)
        x = F.relu(self.dense2(x))
        x = self.batch_norm3(x)
        x = self.dropout3(x)
        x = self.dense3(x)
        #x = F.relu(self.dense3(x))
        return x

DAE Model

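import numpy as np
import torch

# loss helpers used by TransformerAutoEncoder.loss (not shown in the post; assumed to be these functionals)
bce_logits = torch.nn.functional.binary_cross_entropy_with_logits
mse = torch.nn.functional.mse_loss

class TransformerEncoder(torch.nn.Module):
    def __init__(self, embed_dim, num_heads, dropout, feedforward_dim):
        super().__init__()  # required before assigning submodules
        self.attn = torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)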
        self.linear_1 = torch.nn.Linear(embed_dim, feedforward_dim)
        self.linear_2 = torch.nn.Linear(feedforward_dim, embed_dim)
        self.layernorm_1 = torch.nn.LayerNorm(embed_dim)
        self.layernorm_2 = torch.nn.LayerNorm(embed_dim)
    def forward(self, x_in):
        attn_out, _ = self.attn(x_in, x_in, x_in)
        x = self.layernorm_1(x_in + attn_out)
        ff_out = self.linear_2(torch.nn.functional.relu(self.linear_1(x)))
        x = self.layernorm_2(x + ff_out)
        return x

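class TransformerAutoEncoder(torch.nn.Module):
    # parameter names reconstructed from the attributes used below; the defaults are not shown in the post
    def __init__(self, num_inputs, n_cats, n_nums, hidden_size, num_subspaces, embed_dim,
                 num_heads, dropout, feedforward_dim, emphasis, task_weights, mask_loss_weight):
        super().__init__()
        assert hidden_size == embed_dim * num_subspaces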
        self.n_cats = n_cats
        self.n_nums = n_nums
        self.num_subspaces = num_subspaces
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.emphasis = emphasis
        self.task_weights = np.array(task_weights) / sum(task_weights)
        self.mask_loss_weight = mask_loss_weight

        self.excite = torch.nn.Linear(in_features=num_inputs, out_features=hidden_size)
        self.encoder_1 = TransformerEncoder(embed_dim, num_heads, dropout, feedforward_dim)
        self.encoder_2 = TransformerEncoder(embed_dim, num_heads, dropout, feedforward_dim)
        self.encoder_3 = TransformerEncoder(embed_dim, num_heads, dropout, feedforward_dim)
        self.mask_predictor = torch.nn.Linear(in_features=hidden_size, out_features=num_inputs)
        self.reconstructor = torch.nn.Linear(in_features=hidden_size + num_inputs, out_features=num_inputs)

    def divide(self, x):
        # (batch, hidden_size) -> (num_subspaces, batch, embed_dim), the sequence-first layout
        # MultiheadAttention expects by default
        batch_size = x.shape[0]
        x = x.reshape((batch_size, self.num_subspaces, self.embed_dim)).permute((1, 0, 2))
        return x

    def combine(self, x):
        # inverse of divide: (num_subspaces, batch, embed_dim) -> (batch, hidden_size)
        batch_size = x.shape[1]
        x = x.permute((1, 0, 2)).reshape((batch_size, -1))
        return x

    def forward(self, x):
        x = torch.nn.functional.relu(self.excite(x))
        x = self.divide(x)
        x1 = self.encoder_1(x)
        x2 = self.encoder_2(x1)
        x3 = self.encoder_3(x2)
        x = self.combine(x3)
        predicted_mask = self.mask_predictor(x)
        reconstruction = self.reconstructor(torch.cat([x, predicted_mask], dim=1))
        return (x1, x2, x3), (reconstruction, predicted_mask)

    def split(self, t):
        # split the columns into categorical and numerical parts
        return torch.split(t, [self.n_cats, self.n_nums], dim=1)

    def feature(self, x):
        # returns the autoencoder layer outputs (plus reconstruction and mask logits)
        # as one concatenated feature set
        attn_outs, recon_outs = self.forward(x)
        attn_outs = torch.cat([self.combine(x) for x in attn_outs], dim=1)
        masks = torch.cat([x for x in recon_outs], dim=1)
        return torch.cat([attn_outs, masks], dim=1)

    def loss(self, x, y, mask, reduction='mean'):
        _, (reconstruction, predicted_mask) = self.forward(x)
        x_cats, x_nums = self.split(reconstruction)
        y_cats, y_nums = self.split(y)
        w_cats, w_nums = self.split(mask * self.emphasis + (1 - mask) * (1 - self.emphasis))

        cat_loss = self.task_weights[0] * torch.mul(w_cats, bce_logits(x_cats, y_cats, reduction='none'))
        num_loss = self.task_weights[1] * torch.mul(w_nums, mse(x_nums, y_nums, reduction='none'))

        reconstruction_loss = torch.cat([cat_loss, num_loss], dim=1) if reduction == 'none' else cat_loss.mean() + num_loss.mean()
        mask_loss = self.mask_loss_weight * bce_logits(predicted_mask, mask, reduction=reduction)

        return reconstruction_loss + mask_loss if reduction == 'mean' else [reconstruction_loss, mask_loss]

class SwapNoiseMasker(object):
    def __init__(self, probas):
        self.probas = torch.from_numpy(np.array(probas))

    def apply(self, X):
        # Bernoulli draw per cell: 1 where we want to corrupt the data
        should_swap = torch.bernoulli(self.probas.to(X.device) * torch.ones(X.shape, device=X.device))
        # corrupted X: swapped cells take the value of the same column from a randomly permuted row
        corrupted_X = torch.where(should_swap == 1, X[torch.randperm(X.shape[0], device=X.device)], X)
        #calculates the mask which we aim to predict
        mask = (corrupted_X != X).float()
        return corrupted_X, mask

DataSet Structure

class FeatureDataset:
    def __init__(self, features, targets):
        self.features = features
        self.targets = targets
    def __len__(self):
        return (self.features.shape[0])
    def __getitem__(self, idx):
        dct = {
            'x': torch.tensor(self.features[idx, :], dtype=torch.float),
            'y': torch.tensor(self.targets[idx], dtype=torch.float),
        }
        return dct

class TestFeatureDataset:
    def __init__(self, features):
        self.features = features
    def __len__(self):
        return (self.features.shape[0])
    def __getitem__(self, idx):
        dct = {'x': torch.tensor(self.features[idx, :], dtype=torch.float)}
        return dct

Based on your debugging so far, I would guess that the NaNs are created at some point inside the model.
A “brute force” approach would be to register forward hooks on all layers and check their outputs for invalid values to further narrow down the first occurrence of the NaNs.
To do so, you could use this code and call torch.isfinite(tensor).all() inside the hook.
Once you see the first NaN output, you could then also check the inputs to this layer as well as its parameters, and check whether the activations are overflowing or why else the tensor contains invalid values.
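A minimal sketch of that brute-force approach (the helper name `register_nan_hooks` is made up for illustration; it is not the code linked above): every submodule gets a forward hook that records its name when any output tensor is non-finite. Since hooks fire in execution order, the first recorded name is the first layer producing invalid values.

```python
import torch
import torch.nn as nn


def register_nan_hooks(model):
    """Attach a forward hook to every submodule; each hook records the module's
    name when its output contains NaN or Inf. Returns (handles, offenders)."""
    offenders = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Some modules (e.g. MultiheadAttention) return a tuple, so check every tensor in it.
            outs = output if isinstance(output, (tuple, list)) else (output,)
            if any(torch.is_tensor(o) and not torch.isfinite(o).all() for o in outs):
                offenders.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles, offenders


# Usage on a toy MLP whose last layer has one NaN bias entry.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
with torch.no_grad():
    toy[2].bias[0] = float("nan")

handles, offenders = register_nan_hooks(toy)
toy(torch.randn(3, 4))
for h in handles:
    h.remove()  # detach the hooks once debugging is done

print(offenders)  # → ['2']
```

On the real model, the names would be the attribute names, e.g. `dense2`.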

Hi @ptrblck, thanks for your reply.
I added the hooks using your code and could see that the 1st layer doesn’t produce any NaN.
The 2nd and the 3rd layers do.
The outcome of the debugging after the “explosion”:

dense1 ; Has any nan: False ; Count nan vals: 0 (torch.Size([1024, 2048]) shape)
dense2 ; Has any nan: True ; Count nan vals: 1024 (torch.Size([1024, 2048]) shape)
dense3 ; Has any nan: True ; Count nan vals: 1023 (torch.Size([1024, 1]) shape)

If I read this correctly, it means there’s one column producing the NaNs (2nd layer). After debugging further, I can see that this is the 5th column (all values in it are NaN).
Surprisingly, in the 3rd layer we don’t get all NaN in the output, even though the 5th column is NaN for every sample (item 434/1024 has the value 2.0079257).
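The dense2 count in the log fits a single poisoned unit: with a batch of 1024, one all-NaN column gives exactly 1024 NaNs. A small helper along these lines (the name `nan_report` is hypothetical) reports which columns are entirely NaN instead of only a total count, which could be called on the output inside the hook:

```python
import torch


def nan_report(t):
    """Summarise NaNs in a 2-D activation tensor of shape (batch, units)."""
    nan_mask = torch.isnan(t)
    return {
        "total_nan": int(nan_mask.sum()),
        # units (columns) that are NaN for every sample in the batch
        "all_nan_columns": torch.nonzero(nan_mask.all(dim=0)).flatten().tolist(),
        # samples (rows) that contain at least one NaN
        "rows_with_nan": int(nan_mask.any(dim=1).sum()),
    }


# Fake layer output: 4 samples x 6 units, with the 5th unit (index 4) NaN everywhere.
acts = torch.randn(4, 6)
acts[:, 4] = float("nan")
print(nan_report(acts))  # → {'total_nan': 4, 'all_nan_columns': [4], 'rows_with_nan': 4}
```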

Can you please corroborate this theory, or tell me if I should look elsewhere? And, if it’s not asking too much, could you point me to whether this looks like a model construction problem or a data problem?

Thanks once again!

Further debugging showed me that it’s not always the same column causing the NaN “explosion” (but it always seems to be only one column in layer 2).
Could you please help me understand what might be causing this and how to fix?

Since you’ve now narrowed the first occurrence of the NaNs down to the 2nd layer, you could use the hook to check the inputs as well as the module parameters.
The module, input, and output are passed to the forward hook, and I assume you’ve already checked the output for NaNs. If so, check the parameters as well as the inputs. In case the parameters already contain invalid values, check the gradients of this layer from the previous iteration, as they most likely also contain invalid values.
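Sketched as a hook, assuming it is attached only to the suspicious layer (the `make_inspect_hook` name and the report layout are invented for illustration). `named_parameters(recurse=False)` lists the layer’s own parameters; on a layer wrapped with `nn.utils.weight_norm` these are `weight_g`, `weight_v`, and `bias` rather than `weight`.

```python
import torch
import torch.nn as nn


def make_inspect_hook(report):
    """Forward hook that records whether the module's inputs, output, and own
    parameters contain non-finite values. Results are appended to `report`."""
    def hook(module, inputs, output):
        report.append({
            "inputs_finite": all(torch.isfinite(x).all().item() for x in inputs if torch.is_tensor(x)),
            "output_finite": bool(torch.isfinite(output).all()),
            # per-parameter count of non-finite entries; finite parameters are left out
            "bad_params": {
                name: int((~torch.isfinite(p)).sum())
                for name, p in module.named_parameters(recurse=False)
                if not torch.isfinite(p).all()
            },
        })
    return hook


layer = nn.Linear(3, 2)
with torch.no_grad():
    layer.weight[1] = float("nan")  # poison one output row

report = []
handle = layer.register_forward_hook(make_inspect_hook(report))
layer(torch.randn(5, 3))
handle.remove()
print(report[0])  # → {'inputs_finite': True, 'output_finite': False, 'bad_params': {'weight': 3}}
```

Finite inputs together with a non-finite parameter point at the parameter update rather than at the data.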

Thanks @ptrblck! I’m starting to narrow down the problem… now I just need to understand what is causing it and how to fix it.
So the debugging showed me that at layer 2:

  • inputs don’t have any NaN
  • outputs have NaN (that’s what I was looking at before), as mentioned always in a single column (which changes between runs)
  • parameters have NaNs in the weights (more specifically in weight_v)

label: bias - torch.Size([2048]) - 0
label: weight_g - torch.Size([2048, 1]) - 0
label: weight_v - torch.Size([2048, 2048]) - 2048 [columns rows with nan: 60, ]

I tried m.weight[60,:] and m.weight_v[60,:], which also show all NaN.

So the questions now are… what is causing this, and how do I fix it?
Could you help me once again with digging further, or explain what is causing this?

Since some of the parameters already contain NaNs, I guess the last update might have created them, so you could check the gradients of weight_v from the previous iteration(s).
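One way to catch that previous iteration, assuming a standard `zero_grad()` / `backward()` / `step()` loop: inspect every `.grad` after `backward()` and stop before `optimizer.step()` writes invalid values into the weights. The helper name `nonfinite_grads` is hypothetical, and the toy linear model stands in for the real MLP, where the names would look like `dense2.weight_v`.

```python
import torch
import torch.nn as nn


def nonfinite_grads(model):
    """Return {param_name: count} for parameters whose .grad contains NaN/Inf.
    Meant to be called after loss.backward() and before optimizer.step()."""
    bad = {}
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            bad[name] = int((~torch.isfinite(p.grad)).sum())
    return bad


torch.manual_seed(0)
model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
for step in range(5):
    x, y = torch.randn(8, 3), torch.rand(8, 1)
    opt.zero_grad()
    loss = nn.functional.binary_cross_entropy_with_logits(model(x), y)
    loss.backward()
    bad = nonfinite_grads(model)
    if bad:
        # stop at the first iteration whose gradients would write NaN/Inf into the weights
        raise RuntimeError(f"non-finite gradients at step {step}: {bad}")
    opt.step()
```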
It also seems that weight_norm is used, but I don’t know if it’s more sensitive to gradient explosions etc.
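As background on why weight_norm could be more fragile, here is a sketch of one plausible mechanism; it is not confirmed in this thread. `nn.utils.weight_norm` recomputes each output row of the weight as w = g · v / ‖v‖. Rescaling v leaves w unchanged, so the gradient with respect to v is proportional to g / ‖v‖: a row of `weight_v` that shrinks gets a proportionally larger gradient, and a row with zero norm evaluates to 0 · (g / 0) = NaN for the entire row, which shows up as one all-NaN output unit. The helper `v_grad_norm` is made up for the demonstration.

```python
import torch
import torch.nn as nn


def v_grad_norm(scale):
    """Build a weight-normed Linear, scale weight_v by `scale` (which leaves the
    effective weight unchanged) and return the norm of d(loss)/d(weight_v)."""
    torch.manual_seed(0)
    layer = nn.utils.weight_norm(nn.Linear(8, 4))  # deprecated in recent PyTorch, still available
    with torch.no_grad():
        layer.weight_v.mul_(scale)
    x = torch.randn(16, 8)
    layer(x).pow(2).sum().backward()
    return layer.weight_v.grad.norm().item()


ratio = v_grad_norm(1e-3) / v_grad_norm(1.0)
print(f"shrinking ||v|| by 1000x grows the v-gradient by about {ratio:.0f}x")

# A weight_v row with zero norm makes the recomputed weight row 0 * (g / 0) = NaN.
torch.manual_seed(0)
layer = nn.utils.weight_norm(nn.Linear(8, 4))
with torch.no_grad():
    layer.weight_v[2].zero_()
out = layer(torch.randn(3, 8))  # the forward pre-hook recomputes layer.weight from g and v
print(torch.isnan(layer.weight).all(dim=1))  # only row 2 is NaN
print(torch.isnan(out).all(dim=0))           # so output unit 2 is NaN for every sample
```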

Thanks for the suggestion… apparently the problem was nn.utils.weight_norm.
Removing it worked nicely.
I don’t know why it was causing the NaNs, but at least I managed to solve it.
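For reference, a sketch of the MLP with the wrapper removed, as described above. The `use_weight_norm`, `dropout`, and `bn_eps` arguments are hypothetical additions: `dropout` replaces SETUP['mlp_dropout'] (0.3 is an arbitrary placeholder), `bn_eps` keeps the 1e-15 from the post, and the flag lets both variants be built side by side for comparison.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """The MLP from the post with weight_norm made optional (hypothetical flag)."""

    def __init__(self, num_features=3000, num_targets=1, hidden_size=1000,
                 dropout=0.3, bn_eps=1e-15, use_weight_norm=False):
        super().__init__()

        def dense(n_in, n_out):
            layer = nn.Linear(n_in, n_out)
            return nn.utils.weight_norm(layer) if use_weight_norm else layer

        self.batch_norm1 = nn.BatchNorm1d(num_features, eps=bn_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dense1 = dense(num_features, hidden_size)
        self.batch_norm2 = nn.BatchNorm1d(hidden_size, eps=bn_eps)
        self.dropout2 = nn.Dropout(dropout)
        self.dense2 = dense(hidden_size, hidden_size)
        self.batch_norm3 = nn.BatchNorm1d(hidden_size, eps=bn_eps)
        self.dropout3 = nn.Dropout(dropout)
        self.dense3 = dense(hidden_size, num_targets)

    def forward(self, x):
        x = F.relu(self.dense1(self.dropout1(self.batch_norm1(x))))
        x = F.relu(self.dense2(self.dropout2(self.batch_norm2(x))))
        return self.dense3(self.dropout3(self.batch_norm3(x)))  # raw logits for BCEWithLogitsLoss


model = Model(num_features=64, num_targets=1, hidden_size=32)
model.train()
logits = model(torch.randn(16, 64))
print(logits.shape)  # → torch.Size([16, 1])
```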