Gradients not computing (grad stays None for to_qvk and W_0)

So, I have tried to implement a Vision Transformer from scratch, but somehow the gradients of the to_qvk and W_0 layers stay None, even though there isn't a single detached layer in this code. Can anyone figure out why? Thank you.

import logging

import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

logger = logging.getLogger(__name__)


def expand_to_batch(tensor, desired_size):
    # Repeat the (1, 1, dim) cls token along the batch dimension.
    tile = desired_size // tensor.shape[0]
    return repeat(tensor, 'b ... -> (b tile) ...', tile=tile)

def compute_mhsa(q, k, v, scale_factor=1, mask=None):
    # Scaled dot-product attention: similarity scores, optional mask, softmax, weighted sum of values.
    scaled_dot_prod = torch.einsum('... i d , ... j d -> ... i j', q, k) * scale_factor
    if mask is not None:
        scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, float('-inf'))

    attention = torch.softmax(scaled_dot_prod, dim=-1)
    return torch.einsum('... i j , ... j d -> ... i d', attention, v)


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=None):
        super().__init__()
        self.dim_head = (int(dim / heads)) if dim_head is None else dim_head
        _dim = self.dim_head * heads
        self.heads = heads
        self.to_qvk = nn.Linear(dim, _dim * 3, bias=False)
        self.W_0 = nn.Linear(_dim, dim, bias=False)
        self.rearrange = Rearrange("b h t d -> b t (h d)")  # merge heads: (b, h, t, d) -> (b, t, h*d)

    def forward(self, x, mask=None):
        logger.info("============ MULTI HEAD SELF ATTENTION ============\n")
        assert x.dim() == 3
        qkv = self.to_qvk(x)
        logger.info(f"Hasil QKV \n{qkv}")
        q, k, v = nn.Parameter(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.heads))
        logger.info(f"Hasil : Q \n{q}")
        logger.info(f"Hasil : K \n{k}")
        logger.info(f"Hasil : V \n{v}")
        out = nn.Parameter(compute_mhsa(q, k, v, mask=mask))
        logger.info(f"Hasil Scaled dot Product :  \n{out}")
        out = self.rearrange(out)
        logger.info(f"Hasil Concat: \n{out}")
        out = self.W_0(out)
        logger.info(f"Hasil Linear : \n{out}")
        return out

class TransformerBlock(nn.Module):

    def __init__(self, dim, heads=8, dim_head=None,
                 dim_linear_block=1024, activation=nn.GELU,
                 mhsa=None, prenorm=False):
        super().__init__()
        self.mhsa = MultiHeadSelfAttention(dim=dim, heads=heads, dim_head=dim_head)
        self.prenorm = False
        self.norm_1 = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
    def forward(self, x, mask=None):
        logger.info(f"======================== Transformer Block ========================")
        logger.info("Before MHSA")
        y = self.mhsa(x)
        logger.info(f"After MHSA \n{y}")
        y = self.norm_1(x)
        logger.info(f"After Normalization \n{y}")
        out = self.gelu(y)
        logger.info(f"After GeLU : \n{out}")
        logger.info(f"Hasil Normalization \n {out}")
        return out


class TransformerEncoder(nn.Module):
    def __init__(self, dim, blocks=6, heads=8, dim_head=None, dim_linear_block=1024, prenorm=False):
        super().__init__()
        self.transformer_block = TransformerBlock(dim,heads,dim_head,dim_linear_block,prenorm=prenorm)
    def forward(self, x, mask=None):
        logger.info("======================== Transformer Encoder ========================")
        x = self.transformer_block(x,mask)
        logger.info(f"Hasil Forward Transformer Encoder {x}")
        return x


class ViT(nn.Module):
    def __init__(self, *,
                img_dim = 4,
                in_channels=3,
                patch_dim=2,
                num_classes=3,
                dim=3,
                blocks=1,
                heads=3,
                dim_linear_block=3,
                dim_head=None,
                 transformer=None, 
                 classification=None):
        super().__init__()
        assert img_dim % patch_dim == 0, f'patch size {patch_dim} not divisible by img dim {img_dim}'
        self.num_classes = num_classes
        self.p = patch_dim
        self.classification = classification
        tokens = (img_dim // patch_dim) ** 2
        self.token_dim = in_channels * (patch_dim ** 2)
        self.dim = dim
        self.dim_head = (int(self.dim / heads)) if dim_head is None else dim_head
        self.project_patches = nn.Linear(self.token_dim, self.dim,bias=False)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.dim))
        self.pos_emb1D = nn.Parameter(torch.randn(tokens + 1, self.dim))

        if self.classification:
            self.mlp_head = nn.Linear(self.dim, num_classes)

        if transformer is None:
            self.transformer = TransformerEncoder(self.dim, blocks=blocks, heads=heads,
                                                dim_head=self.dim_head,
                                                dim_linear_block=dim_linear_block)
        else:
            self.transformer = transformer
        self.softmax = nn.Softmax()
    def forward(self, img, mask=None):
        logger.info("======================== Vision Transformer ========================")
        logger.info(f"Image Shape : \n{img.shape}")
        logger.info(f"Image Input : \n{img}")
        img_patches = rearrange(img,
                                'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.p, patch_y=self.p)
        logger.info(f"Image Patches Shape : \n{img_patches.shape}")
        logger.info(f"Image Patches : \n{img_patches}")
        batch_size, tokens, _ = img_patches.shape
        logger.info(f"Batch Size : \n {batch_size}")
        logger.info(f"Tokens : \n{tokens}")
        img_patches = self.project_patches(img_patches)
        logger.info(f"project patches size : \n{img_patches.shape}")
        logger.info(f"Project Patches : \n{img_patches}")
        img_patches = torch.cat((expand_to_batch(self.cls_token, desired_size=batch_size), img_patches), dim=1)
        logger.info(f"project patches after expand to token : \n{img_patches.shape}")
        logger.info(f"Images Patches after matrix multiply (cat) : \n{img_patches}")
        img_patches = img_patches + self.pos_emb1D[:tokens + 1, :]
        logger.info(f"Image patches after Pos embed with tokens : \n{img_patches}")
        y = self.transformer(img_patches, mask)
        logger.info(f"Image after Transformer : \n{y}")
        out = nn.Linear(self.dim, self.num_classes)
        out = out(y[:,0,:])
        logger.info(f"MLP Head : \n{out}")
        return out

And here is the log output from inspecting the parameters and their gradients:

for layer, weight in vit.named_parameters():
  print(f"==================Layer {layer}==================================")
  if "bias" in layer:
    print(f"Bias {weight}")
    print(f"Bias Gradient {weight.grad}")
  elif "weight" in layer:
    print(f" Weight {weight}")
  print(f"Gradient {weight.grad}")
==================Layer cls_token==================================
Gradient tensor([[[-4.8729e-05,  1.3597e-05,  3.5132e-05]]])
==================Layer pos_emb1D==================================
Gradient tensor([[-4.8729e-05,  1.3597e-05,  3.5132e-05],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00]])
==================Layer project_patches.weight==================================
 Weight Parameter containing:
tensor([[ 0.0490,  0.0490,  0.0490, -0.0480, -0.0480, -0.0480,  0.0490, -0.0490,
          0.0450, -0.0490,  0.0480, -0.0490],
        [ 0.0490, -0.0480, -0.0480,  0.0480, -0.0480,  0.0470, -0.0490, -0.0490,
         -0.0480,  0.0490, -0.0450, -0.0470],
        [ 0.0490, -0.0490,  0.0480, -0.0480, -0.0490,  0.0450, -0.0490, -0.0470,
          0.0490, -0.0450, -0.0490, -0.0490]], requires_grad=True)
Gradient tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
==================Layer transformer.transformer_block.mhsa.to_qvk.weight==================================
 Weight Parameter containing:
tensor([[ 0.0030,  0.0050,  0.0010],
        [ 0.0010, -0.0010, -0.0010],
        [-0.0030, -0.0030, -0.0030],
        [ 0.0030, -0.0030,  0.0030],
        [ 0.0040,  0.0040,  0.0040],
        [-0.0010,  0.0010,  0.0040],
        [-0.0020, -0.0020, -0.0020],
        [ 0.0020, -0.0020,  0.0020],
        [-0.0020, -0.0020,  0.0010]], requires_grad=True)
Gradient None
==================Layer transformer.transformer_block.mhsa.W_0.weight==================================
 Weight Parameter containing:
tensor([[-0.0010,  0.0010,  0.0040],
        [-0.0020,  0.0020, -0.0020],
        [ 0.0020, -0.0010, -0.0010]], requires_grad=True)
Gradient None
==================Layer transformer.transformer_block.norm_1.weight==================================
 Weight Parameter containing:
tensor([-0.0510,  0.0520,  0.0520], requires_grad=True)
Gradient tensor([ 0.1077, -0.1650, -0.2108])
==================Layer transformer.transformer_block.norm_1.bias==================================
Bias Parameter containing:
tensor([-0.0480, -0.0490,  0.0540], requires_grad=True)
Bias Gradient tensor([ 0.3079,  0.1212, -0.2083])
Gradient tensor([ 0.3079,  0.1212, -0.2083])

You are creating the parameters in the forward method only, which creates new leaf tensors without a history. This will detach the previous tensors from their computation graph and you’ll thus see None gradients for previous layers.
The proper approach is to initialize all parameters in the __init__ method and use them in forward.
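
To make that concrete, here is a minimal, self-contained sketch of the attention module rewritten along those lines: all trainable layers are created in __init__, the nn.Parameter wrappers are removed from forward, the logging is stripped, and the scaled-dot-product step is inlined (with a 1/sqrt(dim_head) scale added for completeness, which the posted compute_mhsa effectively leaves at 1). Treat it as an illustration of the principle, not a drop-in replacement for the full model.

import torch
import torch.nn as nn
from einops import rearrange

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=None):
        super().__init__()
        # All trainable layers are created here, once.
        self.dim_head = (dim // heads) if dim_head is None else dim_head
        _dim = self.dim_head * heads
        self.heads = heads
        self.to_qvk = nn.Linear(dim, _dim * 3, bias=False)
        self.W_0 = nn.Linear(_dim, dim, bias=False)

    def forward(self, x, mask=None):
        assert x.dim() == 3
        qkv = self.to_qvk(x)
        # Plain tensor ops keep the autograd history; no nn.Parameter(...) here.
        q, k, v = rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.heads)
        dots = torch.einsum('... i d , ... j d -> ... i j', q, k) * self.dim_head ** -0.5
        if mask is not None:
            dots = dots.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(dots, dim=-1)
        out = torch.einsum('... i j , ... j d -> ... i d', attn, v)
        out = rearrange(out, 'b h t d -> b t (h d)')  # concatenate heads
        return self.W_0(out)

# Quick check: after a backward pass, no .grad should be None anymore.
mhsa = MultiHeadSelfAttention(dim=12, heads=4)
x = torch.randn(2, 5, 12)
mhsa(x).sum().backward()
for name, p in mhsa.named_parameters():
    print(name, '-> grad is None' if p.grad is None else '-> grad populated')

The same rule covers the classification head: ViT.__init__ already builds self.mlp_head = nn.Linear(self.dim, num_classes) when classification is set, so forward should call self.mlp_head(y[:, 0, :]) instead of constructing a fresh nn.Linear on every pass.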