Backward gradient output is zero except for the class token in the Transformer LayerNorm

I added a backward hook to the norm layer right before mlp_head, but grad_output is all zeros except for the class token.
Can anyone explain this behavior?

Here is the gradient output:
LayerNorm((512,), eps=1e-05, elementwise_affine=True)
grad_output:
(tensor([[[ 0.0218, 0.0294, 0.0096, …, -0.0287, 0.0241, 0.0031],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000]],
[[ 0.0218, 0.0294, 0.0096, …, -0.0287, 0.0241, 0.0031],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, …, 0.0000, 0.0000, 0.0000]]],
device='cuda:0'), None, None)

Here is the backward hook function:
def hook_fn_backward(module, grad_input, grad_output):
    print(module)
    print('grad_output:\n', grad_output[0], grad_output[0].shape)
    print('grad_input:\n', grad_input[0], grad_input[0].shape)
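
For reference, this is roughly how I register the hook (a minimal sketch; the attribute path to the norm before mlp_head and the names model / video / criterion / target are placeholders):

# Minimal sketch of the hook registration (paths and variable names are placeholders)
norm_layer = model.time_transformer.norm   # assumed: the LayerNorm right before mlp_head
norm_layer.register_full_backward_hook(hook_fn_backward)   # register_backward_hook on older PyTorch

out = model(video)              # video: (batch, frames, channels, height, width)
loss = criterion(out, target)
loss.backward()                 # the hook fires here and prints grad_output / grad_input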

Here is the network structure:
STAM(
(to_patch_embedding): Sequential(
(0): Rearrange('b f c (h p1) (w p2) -> b f (h w) (p1 p2 c)', p1=32, p2=32)
(1): Linear(in_features=3072, out_features=512, bias=True)
)
(dropout): Dropout(p=0.0, inplace=False)
(space_transformer): Transformer(
(layers): ModuleList(
(0): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(1): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(2): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(3): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(4): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(5): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(6): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(7): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(8): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(9): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(10): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(11): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(time_transformer): Transformer(
(layers): ModuleList(
(0): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(1): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(2): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(3): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(4): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
(5): ModuleList(
(0): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): Attention(
(to_qkv): Linear(in_features=512, out_features=1536, bias=False)
(to_out): Sequential(
(0): Linear(in_features=512, out_features=512, bias=True)
(1): Dropout(p=0.0, inplace=False)
)
)
)
(1): PreNorm(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(fn): FeedForward(
(net): Sequential(
(0): Linear(in_features=512, out_features=2048, bias=True)
(1): GELU()
(2): Dropout(p=0.0, inplace=False)
(3): Linear(in_features=2048, out_features=512, bias=True)
(4): Dropout(p=0.0, inplace=False)
)
)
)
)
)
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)
(mlp_head): Linear(in_features=512, out_features=100, bias=True)
)
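
In case it helps, here is a stripped-down snippet that shows the same pattern outside the full model (a standalone sketch, not the STAM code: just a LayerNorm followed by a head that only reads token 0):

import torch
import torch.nn as nn

# Standalone sketch: LayerNorm -> head that only consumes the class token (token 0)
norm = nn.LayerNorm(512)
head = nn.Linear(512, 100)

def hook_fn(module, grad_input, grad_output):
    print(module)
    print('grad_output:\n', grad_output[0], grad_output[0].shape)

norm.register_full_backward_hook(hook_fn)

x = torch.randn(2, 6, 512, requires_grad=True)   # (batch, tokens, dim) with 6 tokens, as above
out = head(norm(x)[:, 0])                         # the head only sees token 0
out.sum().backward()                              # printed grad_output is zero for tokens 1..5

If the head only reads the class token like this, the grad_output at the norm would be zero for every other token, which matches the pattern above. Is that what is happening here, or is something else going on?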