The model starts producing NaN tensors at the very beginning of the forward pass, in embed_x and in the critical_features computed by torch.index_select, which is very weird (this happens when the clip_grad_norm max norm is around 4).
Alternatively, I get `RuntimeError("Function 'LogSoftmaxBackward0' returned nan values in its 0th output.")` when the clip_grad_norm max norm is around 1, even though I do not use LogSoftmax anywhere in my model.
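That RuntimeError wording is what torch.autograd anomaly detection produces, so anomaly mode is enabled for the failing runs. Below is a minimal sketch of the kind of harness that surfaces the first NaN tensor during the forward pass (anomaly detection plus forward hooks); `attach_nan_hooks` is just an illustrative helper for this sketch, not part of the model code further down:

```python
import torch
import torch.nn as nn

def attach_nan_hooks(model: nn.Module):
    """Print the first submodules whose forward outputs contain NaN."""
    def check_nan(module, inputs, output):
        outs = output if isinstance(output, tuple) else (output,)
        for out in outs:
            if torch.is_tensor(out) and torch.isnan(out).any():
                print(f"NaN in output of {module.__class__.__name__}")
    for module in model.modules():
        module.register_forward_hook(check_nan)

# The "Function 'LogSoftmaxBackward0' returned nan values ..." message is the
# wording emitted by anomaly detection during the backward pass.
torch.autograd.set_detect_anomaly(True)
```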
I use clip_grad_norm, amp.autocast(), and GradScaler() during training.
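A minimal sketch of how clip_grad_norm, autocast(), and GradScaler() fit together in one training step, assuming the pattern from the torch.cuda.amp docs (model, optimizer, criterion, inputs, targets, and max_norm are placeholders here, not the actual training code):

```python
# Sketch of the AMP + gradient-clipping combination; all names below are
# placeholders standing in for the real training objects.
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

def train_step(model, optimizer, criterion, inputs, targets, max_norm):
    optimizer.zero_grad()
    with autocast():
        outputs, attention = model(inputs)
        loss = criterion(outputs, targets)
    scaler.scale(loss).backward()
    # Unscale the gradients first so that clip_grad_norm_ sees the true
    # gradient norms rather than the loss-scaled ones.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```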
The model is as follows:
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product(q, k, v, mask=None):
    d_k = q.size()[-1]
    # print(f"dk {d_k}")
    # print(f"k with shape after transpose {k.transpose(-2, -1).shape}")
    attn_logits = torch.matmul(q, k.transpose(-2, -1))
    print(f"attn_logits {attn_logits} with shape {attn_logits.shape}")
    attn_logits = attn_logits / math.sqrt(d_k)
    # if mask is not None:
    #     attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
    attention = F.softmax(attn_logits, dim=-1)
    print(f"attention {attention} with shape {attention.shape}")
    values = torch.matmul(attention.transpose(-2, -1), v)
    # print(f"values with shape {values.shape}")
    return values, attention

class TopkHeadAttention(nn.Module):
    def __init__(self, config):
        super(TopkHeadAttention, self).__init__()
        self.embed_dim = config.in_size * config.topk_heads
        assert self.embed_dim % config.topk_heads == 0, "Embedding dimension must be 0 modulo number of heads."
        self.num_heads = config.topk_heads
        self.head_dim = self.embed_dim // config.topk_heads
        self.num_classes = config.classes
        # self.mask = config.mask
        self.classification = config.topkhead_classification
        self.critical_features_from = config.critical_features_from
        # Stack all weight matrices 1...h together for efficiency
        # Note that in many implementations you see "bias=False", which is optional
        if config.embed_module == "Dense":
            self.embedding = nn.Linear(config.in_size, self.embed_dim)
            self.query = nn.Linear(self.embed_dim, self.embed_dim)
            self.value = nn.Linear(self.embed_dim, self.embed_dim)
        elif config.embed_module == "Sparse":
            self.embedding = nn.Sequential(
                nn.Linear(config.in_size, self.embed_dim),
                nn.Dropout(config.dropout),
                nn.LayerNorm(self.embed_dim),
                nn.GELU()
            )
            self.query = nn.Sequential(
                nn.Linear(self.embed_dim, self.embed_dim),
                nn.Dropout(config.dropout),
                nn.LayerNorm(self.embed_dim),
                nn.GELU()
            )
            self.value = nn.Sequential(
                nn.Linear(self.embed_dim, self.embed_dim),
                nn.Dropout(config.dropout),
                nn.LayerNorm(self.embed_dim),
                nn.GELU()
            )
        self.projection_out = nn.Sequential(
            nn.Linear(self.embed_dim, config.in_size),
            nn.Dropout(config.dropout)
        )
        self.projection = config.projection
        self.instance_head = nn.Linear(config.in_size, config.classes)
        self.head = nn.Conv1d(config.classes, config.classes, kernel_size=config.in_size)

    def forward(self, x):
        seq_length, _ = x.shape
        c = self.instance_head(x).squeeze()
        _, topk_idx = torch.topk(c, self.num_heads, dim=0)
        embed_x = self.embedding(x)
        print(f"embed_x {embed_x} should be with shape [seq_length, embed_dim]: {embed_x.shape}")
        embed_features = embed_x.reshape(seq_length, self.num_heads, self.head_dim)
        print(f"embed_features {embed_features} after reshape should be with shape [seq_length, num_heads, head_dim]: {embed_features.shape}")
        embed_features = embed_features.permute(1, 0, 2)  # [num_heads, seq_length, head_dim]
        # print(f"embed_features after permutation should be with shape [num_heads, seq_length, head_dim]: {embed_features.shape}")
        q = self.query(embed_x).reshape(seq_length, self.num_heads, self.head_dim)
        v = self.value(embed_x).reshape(seq_length, self.num_heads, self.head_dim)
        if self.critical_features_from == "embedding":
            critical_features = torch.stack([torch.index_select(embed_features[i], 0, topk_idx[i]) for i in range(self.num_heads)])
        else:
            critical_features = torch.stack([torch.index_select(x, 0, topk_idx[i]) for i in range(self.num_heads)])
        print(f"critical_features {critical_features} should be with shape [num_heads, num_classes, head_dim]: {critical_features.shape}")
        critical_features = critical_features.permute(1, 0, 2).reshape(self.num_classes, self.embed_dim)
        k = self.query(critical_features).reshape(self.num_classes, self.num_heads, self.head_dim)  # [num_classes, num_heads, head_dim]
        # print(f"k should be with shape [num_classes, num_heads, head_dim]: {k.shape}")
        # permute q, k, v so that the head dimension comes first
        q = q.permute(1, 0, 2)
        # print(f"q should be with shape [num_heads, seq_length, head_dim]: {q.shape}")
        k = k.permute(1, 0, 2)
        # print(f"k should be with shape [num_heads, num_classes, head_dim]: {k.shape}")
        v = v.permute(1, 0, 2)
        # print(f"v {v} should be with shape [num_heads, seq_length, head_dim]: {v.shape}")
        values, attention = scaled_dot_product(q, k, v, mask=None)
        values = values.permute(1, 0, 2)  # [num_classes, num_heads, head_dim]
        print(f"values {values} should be with shape [num_classes, num_heads, head_dim]: {values.shape}")
        if self.projection:
            values = values.reshape(self.num_classes, self.head_dim * self.num_heads)
            print(f"values {values[:, :5]} should be with shape [num_classes, embed_dim]: {values.shape}")
            o = self.projection_out(values)
            print(f"o with projection {o} should be with shape [num_classes, in_size]: {o.shape}")
        else:
            o = torch.sum(values, dim=1)
            print(f"o with torch.sum {o} should be with shape [num_classes, in_size]: {o.shape}")
        if self.classification:
            o = self.head(o)
        else:
            o = o
        return o, attention, topk_idx

class CriticalFeaturesBlock(nn.Module):
    def __init__(self, config):
        super(CriticalFeaturesBlock, self).__init__()
        self.topk_attention = TopkHeadAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.in_size, config.in_size),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.in_size),
            nn.GELU()
        )
        self.mlp_c = nn.Sequential(
            nn.Linear(config.classes, config.classes),
            nn.Dropout(config.dropout),
            nn.LayerNorm(config.classes),
            nn.GELU()
        )
        self.instance_head = nn.Linear(config.in_size, config.classes)
        self.add_mlp = config.add_mlp
        self.norm = nn.LayerNorm(config.in_size)
        self.norm_mlp = nn.LayerNorm(config.in_size)
        self.norm_c = nn.LayerNorm(config.classes)
        self.norm_mlp_c = nn.LayerNorm(config.classes)
        self.dropout = nn.Dropout(config.dropout)
        self.num_heads = config.topk_heads
        self.norm_o = nn.LayerNorm(config.in_size)

    def forward(self, x):
        o, attention, topk_idx = self.topk_attention(x)
        print(f"the output of topkheadattention: {o} o should be with shape [num_classes, in_size]: {o.shape}")
        o = self.norm_o(o)
        # method 1: use index_add to add results from top-k multi-head attention to the residual stream
        # for i in range(self.num_heads):
        #     x = x.index_add(0, topk_idx[i], o.float())
        print(f"x in the block: {x} after index_add should be with shape [seq_length, in_size]: {x.shape}")
        # method 2: add results from top-k multi-head attention to the residual stream
        xo = torch.zeros_like(x)
        for idxs in topk_idx:
            for i, j in enumerate(idxs):
                xo[j] = xo[j] + o[i]
        x = x + xo
        x = self.norm(x)
        if self.add_mlp:
            x = self.norm_mlp(x + self.mlp(x))
        else:
            x = x
        return x, attention

class CF_Transformer(nn.Module):
    def __init__(self, config):
        super(CF_Transformer, self).__init__()
        self.config = config
        self.blocks = nn.ModuleList([CriticalFeaturesBlock(config) for _ in range(config.num_layers)])
        self.head = nn.Linear(config.in_size, config.classes)
        # self.head = nn.Conv1d(1, config.classes, kernel_size=config.in_size)
        # self.head = nn.Conv1d(config.classes, config.classes, kernel_size=config.in_size)
        self.apply(self.init_weights)
        self.classification = config.classification

    def init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
            nn.init.kaiming_normal_(m.weight)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        seq_length = x.shape[0]
        attention = []
        for block in self.blocks:
            x, attention_block = block(x)
            attention.append(attention_block)
        print(f"x in transformer {x} should be with shape [seq_length, in_size]: {x.shape}")
        output = self.head(x).squeeze()
        # print(f"output should be with shape [1, num_classes, 1]: {output.shape}")
        attention = torch.stack(attention).squeeze()
        if self.classification == "mean":
            output = torch.mean(output, dim=0)
        elif self.classification == "avgpool1d":
            output = F.adaptive_avg_pool1d(output.T, 1).squeeze()
        elif self.classification == "LPPool1d":
            output = F.lp_pool1d(output.T, 2, 1).squeeze()
        elif self.classification == "maxpool1d":
            output = F.adaptive_max_pool1d(output.T, 1).squeeze()
        return output, attention
```
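For reference, a minimal sketch of instantiating CF_Transformer and running a forward pass on a 2-D input of shape [seq_length, in_size]; the config values below are illustrative placeholders, not the actual hyperparameters:

```python
# Minimal smoke test; the config values are placeholders, not the real
# hyperparameters used during training.
from types import SimpleNamespace
import torch

config = SimpleNamespace(
    in_size=64, topk_heads=4, classes=5, num_layers=2, dropout=0.1,
    embed_module="Sparse", critical_features_from="embedding",
    projection=True, add_mlp=True, topkhead_classification=False,
    classification="mean",
)
model = CF_Transformer(config)
x = torch.randn(100, config.in_size)  # [seq_length, in_size]
output, attention = model(x)
print(output.shape)  # torch.Size([classes])
```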