Sparse linear layer: how to use it with SparseAdam?

class Net(torch.nn.Module):
    """A linear layer whose weight is stored as a sparse COO tensor.

    The weight is a dense-initialized square matrix of size
    (shape1*shape2, shape1*shape2) converted to sparse, so that
    ``torch.sparse.mm`` produces sparse gradients usable by
    ``torch.optim.SparseAdam`` (presumably the intent here — confirm
    against the optimizer setup).
    """

    def __init__(self, shape1, shape2):
        super(Net, self).__init__()
        self.shape1 = shape1
        self.shape2 = shape2
        # Feature dimension is the flattened (shape1 x shape2) input.
        # BUG FIX: original used an undefined name `wl`; derive the size
        # from the constructor arguments instead.
        n = shape1 * shape2
        # torch.empty instead of torch.Tensor: same uninitialized
        # allocation, but explicit; kaiming_uniform_ then fills it in place.
        weight = init.kaiming_uniform_(torch.empty(n, n), a=math.sqrt(5))
        self.weight = torch.nn.Parameter(weight.to_sparse())
        print(self.weight)

    def forward(self, x, size):
        """Apply the sparse linear map.

        Args:
            x: input tensor reshapeable to (size, shape1*shape2).
            size: batch size.

        Returns:
            Tensor of shape (size, shape1, shape2).
        """
        x = x.reshape(size, -1)
        # sparse.mm requires the sparse operand first, hence weight @ x^T.
        # BUG FIX: the product has shape (features, size); transpose back
        # before reshaping, otherwise the reshape mixes samples together.
        x = torch.sparse.mm(self.weight, x.t()).t()

        # BUG FIX: original referenced undefined names `w, l`; use the
        # stored constructor shapes.
        return x.reshape(size, self.shape1, self.shape2)

I tried to implement a sparse linear network, but my initial weight was always set to zeros. That has been fixed by adding another call to an init function.
However, is there an example of how to use this with the SparseAdam optimizer?

Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], requires_grad=True)
Current learning rate: 0.0

01