Model produces NaN tensors

The model starts to produce NaN tensors at the very beginning of the forward pass, in `embed_x` and in the `critical_features` computed by `torch.index_select`, which is very weird (this happens when the clip_grad_norm max norm is around 4).

Or "RuntimeError(“Function ‘LogSoftmaxBackward0’ returned nan values in its 0th ou
tput.”) " (when the clip_grad_norm is around 1) but I do not use LogSoftmax for any function in my model

I apply clip_grad_norm, amp.autocast(), and GradScaler() during training.
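
Roughly, the training step follows the standard AMP pattern below (a minimal sketch rather than my exact loop; `model`, `optimizer`, `criterion`, `loader`, and `max_norm` are placeholders):

```python
scaler = torch.cuda.amp.GradScaler()

for x, label in loader:                       # placeholder data loader
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        out, attention = model(x)
        loss = criterion(out, label)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                # unscale before clipping so the true gradient norm is clipped
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)   # max norm is around 1 or 4 in my runs
    scaler.step(optimizer)
    scaler.update()
```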

The model is as follows:

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    # print(f"dk {d_k}")
    # print(f"k with shape after transpose  {k.transpose(-2, -1).shape}")
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    print(f"attn_logits{attn_logits} with shape {attn_logits.shape}")
    attn_logits = attn_logits / math.sqrt(d_k)
    # if mask is not None:
    #     attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    print(f"attention {attention} with shape {attention.shape}")
    values = torch.matmul(attention.transpose(-2,-1), v)
    # print(f"values with shape {values.shape}")
    return values, attention

class TopkHeadAttention(nn.Module):
    def __init__(self, config):
        super(TopkHeadAttention,self).__init__()
        self.embed_dim = config.in_size * config.topk_heads
        assert self.embed_dim % config.topk_heads == 0, "Embedding dimension must be 0 modulo number of heads."
        self.num_heads = config.topk_heads
        self.head_dim = self.embed_dim // config.topk_heads
        self.num_classes = config.classes
        # self.mask = config.mask
        self.classification = config.topkhead_classification
        self.critical_features_from = config.critical_features_from

        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False" which is optional
        if config.embed_module =="Dense":
            self.embedding = nn.Linear(config.in_size, self.embed_dim)
            self.query = nn.Linear(self.embed_dim, self.embed_dim)
            self.value = nn.Linear(self.embed_dim, self.embed_dim)
    
        elif config.embed_module =="Sparse":
            self.embedding = nn.Sequential(
                nn.Linear(config.in_size, self.embed_dim),
                nn.Dropout(config.dropout),
                nn.LayerNorm(self.embed_dim),
                nn.GELU()
            )
            self.query = nn.Sequential(
                nn.Linear(self.embed_dim, self.embed_dim),
                nn.Dropout(config.dropout),
                nn.LayerNorm(self.embed_dim),
                nn.GELU()
            )
            self.value = nn.Sequential(
                nn.Linear(self.embed_dim, self.embed_dim),
                nn.Dropout(config.dropout),
                nn.LayerNorm(self.embed_dim),
                nn.GELU()
            )

        self.projection_out = nn.Sequential(nn.Linear(self.embed_dim, config.in_size),nn.Dropout(config.dropout))
        self.projection = config.projection

        self.instance_head = nn.Linear(config.in_size, config.classes)
        self.head = nn.Conv1d(config.classes, config.classes, kernel_size=config.in_size)


    def forward(self, x):
        seq_length, _ = x.shape
        c = self.instance_head(x).squeeze()
        _,topk_idx = torch.topk(c, self.num_heads, dim=0)
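        # topk_idx: [num_heads, num_classes]; row i holds, per class, the index of the i-th highest-scoring instance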

        embed_x = self.embedding(x)
        print(f"embed_x {embed_x} should be with shape[seq_length,embed_dim]: {embed_x.shape}")
        embed_features = embed_x.reshape(seq_length, self.num_heads, self.head_dim)
        print(f"embed_features {embed_features} after reshape should be with shape[seq_length, num_heads, head_dim]: {embed_features.shape}")
        embed_features = embed_features.permute(1, 0, 2) # [num_heads, seq_length, head_dim]
        # print(f"embed_features after permutation should be with shape[num_heads, seq_length, head_dim]: {embed_features.shape}")
        q = self.query(embed_x).reshape(seq_length, self.num_heads, self.head_dim)
        v = self.value(embed_x).reshape(seq_length, self.num_heads, self.head_dim)
        if self.critical_features_from == "embedding":
            critical_features = torch.stack([torch.index_select(embed_features[i], 0, topk_idx[i]) for i in range(self.num_heads)])
        else:
            critical_features = torch.stack([torch.index_select(x, 0, topk_idx[i]) for i in range(self.num_heads)])
        
        print(f"critical_features {critical_features} should be with shape[num_heads, num_classes, head_dim]: {critical_features.shape}")
        critical_features = critical_features.permute(1,0,2).reshape(self.num_classes, self.embed_dim)
        k = self.query(critical_features).reshape(self.num_classes,self.num_heads,self.head_dim) # [num_classes, num_heads, head_dim]
        # print(f"k should be with shape[num_classes, num_heads, head_dim]: {k.shape}")

        #permute q k, v to [num_heads, seq_length, head_dim]
        q = q.permute(1, 0, 2)
        # print(f"q should be with shape[num_heads, seq_length, head_dim]: {q.shape}")
        k = k.permute(1, 0, 2)
        # print(f"k should be with shape[num_heads, num_classes, head_dim]: {k.shape}")
        v = v.permute(1, 0, 2)
        # print(f"v {v}should be with shape[num_heads, seq_length, head_dim]: {v.shape}")
        values,attention = scaled_dot_product(q, k, v, mask=None)
        values = values.permute(1, 0, 2) # [num_classes, num_heads, head_dim]
        print(f"values {v} should be with shape[num_classes, num_heads, head_dim]: {values.shape}")

        if self.projection:
                
            values = values.reshape(self.num_classes,self.head_dim * self.num_heads)
            print(f"values {values[:,:5]} should be with shape[num_classes, embed_dim]: {values.shape}")
            o = self.projection_out(values)
            print(f"o with projection {o} should be with shape[num_classes, in_size]: {o.shape}")
        else:
            o = torch.sum(values, dim=1)
            print(f"o with torch.sum  {o} should be with shape[num_classes, in_size]: {o.shape}")

        if self.classification:
            o = self.head(o)
        else:
            o = o
              
        return o,attention,topk_idx



class CriticalFeaturesBlock(nn.Module):
    def __init__(self, config):
        super(CriticalFeaturesBlock, self).__init__()
        self.topk_attention = TopkHeadAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU()
        )
        self.mlp_c = nn.Sequential(
            nn.Linear(config.classes, config.classes),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.classes),
            nn.GELU()
        )
        self.instance_head = nn.Linear(config.in_size, config.classes)
        self.add_mlp = config.add_mlp 
        self.norm = nn.LayerNorm(config.in_size)
        self.norm_mlp = nn.LayerNorm(config.in_size)
        self.norm_c = nn.LayerNorm(config.classes)
        self.norm_mlp_c = nn.LayerNorm(config.classes)
        self.dropout = nn.Dropout(config.dropout)
        self.num_heads = config.topk_heads
        self.norm_o = nn.LayerNorm(config.in_size)


    def forward(self, x):
        
        o,attention,topk_idx= self.topk_attention(x)
        print(f"the output of topkheadattention: {o} o should be with shape[num_classes, in_size ]: {o.shape}")
        o = self.norm_o(o)
        #method 1 use index_add to add results from topk multihead attention to residual stream 
        #for i in range(self.num_heads):
           # x = x.index_add(0, topk_idx[i], o.float())
        print(f"x in the block :{x} after index_add should be with shape[seq_length, in_size]: {x.shape}")

        #method 2 to add results from topk multihead attention to residual stream.
        xo = torch.zeros_like(x)
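        # for each head, add class i's output o[i] back at the instance position that this head selected for class i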
        for idxs in topk_idx:
            for i,j in enumerate(idxs):
                xo[j] = xo[j] + o[i]
        x = x + xo
        x = self.norm(x)
        if self.add_mlp:
            x = self.norm_mlp(x + self.mlp(x))
        else:
            x = x 
        return x,attention


class CF_Transformer(nn.Module):
    def __init__(self, config):
        super(CF_Transformer, self).__init__()
        self.config = config
        self.blocks = nn.ModuleList([CriticalFeaturesBlock(config) for _ in range(config.num_layers)])
        self.head = nn.Linear(config.in_size, config.classes)
        # self.head = nn.Conv1d(1, config.classes, kernel_size=config.in_size)       
        # self.head = nn.Conv1d(config.classes, config.classes, kernel_size=config.in_size)
        self.apply(self.init_weights)
        self.classification = config.classification
        
    def init_weights(self, m):
       
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
            
        elif isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)

        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        seq_length = x.shape[0]
        attention = []
        for block in self.blocks:
            x,attention_block = block(x)
            attention.append(attention_block)
        
        print(f"x in transformer {x} should be with shape[seq_length,in_size]: {x.shape}")
        output = self.head(x).squeeze()
        # print(f"output should be with shape[1, num_classes, 1]: {output.shape}")
        attention = torch.stack(attention).squeeze()
        if self.classification == "mean":
            output = torch.mean(output,dim=0)
        elif self.classification == "avgpool1d":
            output = F.adaptive_avg_pool1d(output.T, 1).squeeze()
        elif self.classification == "LPPool1d":
            output = F.lp_pool1d(output.T, 2, 1).squeeze()
        elif self.classification == "maxpool1d":
            output = F.adaptive_max_pool1d(output.T, 1).squeeze()
        return output,attention

```

Did you try checking for NaN and inf values before putting the data into the model?

```python
def check_inf_nan(x):
    if torch.any(torch.isinf(x)):
        print("Inf found!")
    if torch.any(torch.isnan(x)):
        print("NaN found!")
```

No, I did not, but I use `torch.autograd.set_detect_anomaly(True)` before tuning the hyperparameters. I also use print statements inside the model, so I can see where and which tensor produces NaN.
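
A forward-hook variant of those prints (just a sketch, not the code I actually ran) would flag the first module whose output goes non-finite:

```python
def add_nan_hooks(model):
    # report every module whose output contains NaN/Inf during the forward pass
    def make_hook(name):
        def hook(module, inputs, output):
            outs = output if isinstance(output, (tuple, list)) else (output,)
            for o in outs:
                if torch.is_tensor(o) and not torch.isfinite(o).all():
                    print(f"non-finite output in {name} ({type(module).__name__})")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))
```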

Are there any instances in which `d_k = q.size()[-1]` is zero?

d_k is always 1000, which is the dimension of the feature vectors.

The error message is as follows:

```
ERROR Run 9gx2e7cn errored: RuntimeError("Function 'LogSoftmaxBackward0' returned nan values in its 0th output.")

/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/site-packages/torch/autograd/__init__.py:173: UserWarning:
 Error detected in LogSoftmaxBackward0. Traceback of forward call that caused the error:
  File "/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/threading.py", line 937, in _bootstrap
    self._bootstrap_inner()
  File "/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/threading.py", line 980, in _bootstrap_inner
    self.run()
  File "/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/threading.py", line 917, in run
    self._target(*self._args, **self._kwargs)
  File "/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/site-packages/wandb/agents/pyagent.py", line 300, in _run_job
    self._function()
  File "/system/user/publicwork/yitaocai/Master_Thesis/tune_topk_multiheadattention.py", line 360, in train
    train_loss = run_bags(model, train_df, optimizer, criterion, config)
  File "/system/user/publicwork/yitaocai/Master_Thesis/tune_topk_multiheadattention.py", line 83, in run_bags
    loss = criterion(out, label)
  File "/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/site-packages/torch/nn/modules/loss.py", line 1164, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/system/user/yitaocai/miniconda3/envs/ml/lib/python3.9/site-packages/torch/nn/functional.py", line 3014, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
 (Triggered internally at /opt/conda/conda-bld/pytorch_1659484809662/work/torch/csrc/autograd/python_anomaly_mode.cpp:102.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
```

In order to debug your code, it would be helpful to have a complete, minimal example which reproduces the error. For example:

```python
# define custom classes
class CustomModelClass(nn.Module):
    ...

# instantiate the model with a config that reproduces the error
model = CustomModelClass(emb_dim=10000, num_heads=8)

# run the model to reproduce the error
for i in range(100):
    x = torch.randint(0, 9999, (4, 1000))
    x = model(x)
    loss = ...
```

Thank you!
I found that the problem only happens when I use the W&B sweep hyperparameter tuning process; the plain training run seems fine.

Did you check if some hyperparameters cause the model to diverge, creating an exploding loss, which eventually overflows?

How do we check that?

Regarding training, I found that a few losses are still NaN, but that does not seem to affect training.
I tried to set the loss to a fixed small value when it is NaN:

```python
if torch.isnan(loss):
    loss = torch.tensor(1e-6).to(config.device)
```