Attempted to use an uninitialized parameter in <method 'element_size' of 'torch._C._TensorBase' objects>

I tried to set up this module with nn.LazyLinear, but I still get the error: ValueError: Attempted to use an uninitialized parameter in <method 'element_size' of 'torch._C._TensorBase' objects>. This error happens when you are using a LazyModule or explicitly manipulating torch.nn.parameter.UninitializedParameter objects; the documented fix is to call forward with a dummy batch so the parameters are initialized before any other torch function touches them. But I thought I already do a dummy pass inside forward, so why does the error persist?
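
For reference, this is the pattern the error message is asking for (a minimal sketch; the shapes are illustrative): lazy parameters must be materialized by one forward call before anything else, such as element_size or an optimizer, touches them.

import torch
import torch.nn as nn

lazy = nn.LazyLinear(32)            # weight/bias start as UninitializedParameter
dummy = torch.randn(4, 16)          # dummy batch: batch_size=4, in_features=16
with torch.no_grad():
    lazy(dummy)                     # first call infers in_features=16 and materializes the parameters
# Only now is it safe to inspect the parameters:
print(lazy.weight.element_size())   # 4 (float32) -- raises the error above if called before the dummy pass
opt = torch.optim.SGD(lazy.parameters(), lr=0.1)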

import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """
    Expert layer for Mixture-of-Experts (MoE) models.

    Attributes:
        l1 (nn.Module): Lazy linear layer for the input-to-hidden (gate) transformation.
        l2 (nn.Module): Linear layer for the hidden-to-output transformation, created on the first forward pass.
        l3 (nn.Module): Lazy linear layer for the input-to-hidden (value) transformation.
    """

    # ModelArgs.device comes from the surrounding config in my codebase.
    def __init__(self, inter_dim: int, device: str = ModelArgs.device):
        """
        Initializes the Expert layer.

        Args:
            inter_dim (int): Hidden layer dimensionality; the input size is inferred lazily.
            device (str): Device on which the layers are created.
        """
        super().__init__()
        self.inter_dim = inter_dim
        # Input size is inferred on the first call, so these start as UninitializedParameters.
        self.l1 = nn.LazyLinear(inter_dim, bias=True, device=device)
        self.l3 = nn.LazyLinear(inter_dim, bias=True, device=device)
        # l2 depends on the input feature size, so it is created on the first forward pass.
        self.is_initialized = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the Expert layer.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor after expert computation.
        """
        if not self.is_initialized:
            # Perform a dummy forward pass to initialize the LazyLinear layers
            with torch.no_grad():
                _ = self.l1(x)
                _ = self.l3(x)
            # Use the last (feature) dimension, not dim 1, so 3-D inputs also work
            in_features = x.size(-1)
            self.l2 = nn.Linear(self.inter_dim, in_features).to(x.device)
            self.is_initialized = True

        x = F.silu(self.l1(x)) * self.l3(x)
        x = self.l2(x)
        return x
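
For context, this is roughly how I exercise the module (the dimensions are illustrative, and it assumes ModelArgs is defined as in my codebase, since the default device argument references it):

expert = Expert(inter_dim=64)        # or Expert(inter_dim=64, device="cpu")
batch = torch.randn(8, 32)           # illustrative batch: batch_size=8, features=32
out = expert(batch)                  # first call materializes l1/l3 and creates l2
print(out.shape)                     # torch.Size([8, 32])

As far as I understand, the error still appears if anything calls a tensor method such as element_size on the still-uninitialized parameters before that first call (parameter summaries, memory-size estimates, some distributed wrappers), so that is where I am looking.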