Trouble with backward - "expanded sparse_tensor.indices"

PyTorch 1.13.0
OS: Windows

I am trying to implement a custom sparse linear layer. I am using torch.sparse.mm(), which according to the documentation supports sparse gradients (mentioned here: torch.sparse — PyTorch 1.13 documentation, under ‘Supported Operations’). Below is a small snippet to illustrate my question, using a small x with 4 channels (x_1, x_2, x_3, x_4). I want to run 2 linear models:

  1. (x_1, x_3, x_4) against y_1
  2. (x_2, x_3, x_4) against y_2
    (In general I will have 200–350 channels, where each channel, along with the last two, will be regressed against its own target; see the layout sketch below.)
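
To make the intended sparsity pattern concrete, this is the dense layout the two models should share (an illustration only; the numbers are the w1/w2 coefficients defined in the runner below):

import torch

# Rows are channels x_1..x_4, columns are targets y_1, y_2; the zeros are the
# connections that must never be trained.
W_dense = torch.tensor([[ 0.10, 0.00],   # x_1 -> y_1 only
                        [ 0.00, 0.25],   # x_2 -> y_2 only
                        [ 0.20, 0.30],   # x_3 -> both
                        [-0.10, 0.20]],  # x_4 -> both
                       dtype=torch.float64)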

I show the output for the first 2 epochs of MySparseLinear below.

Questions:

  1. Why are weight.indices “augmented” with extra entries after the first optimizer.step()? (See the small diagnostic after these questions for exactly what I mean.)
  2. Am I correct in using the provided optimizers with a COO sparse parameter matrix?
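
For question 1, this is the small diagnostic I run right after the first optimizer.step() to show what I mean by “augmented” (slin is the module and optimizer the SGD instance from the runner below):

# Diagnostic only: compare the raw index list with the coalesced one.
w = slin.weight.detach()
print(w._indices())            # raw COO indices; in my run nnz grows from 6 to 8 here
print(w.coalesce().indices())  # same coordinates with duplicates merged by coalesce()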

Any help is appreciated.

from typing import Dict

import torch
import torch.nn as nn


class MySparseLinear(nn.Module):
    def __init__(self,
                 weights: torch.Tensor) -> None:
        super().__init__()
        # Sparsity pattern: column 0 (y_1) uses channels 0, 2, 3;
        # column 1 (y_2) uses channels 1, 2, 3.
        self.indices = torch.tensor([[0, 1, 2, 2, 3, 3],
                                     [0, 1, 0, 1, 0, 1]], dtype=torch.long)
        self.weight = nn.Parameter(torch.sparse_coo_tensor(indices=self.indices,
                                                           values=weights,
                                                           size=(4, 2),
                                                           dtype=torch.float64))

    def forward(self,
                data: Dict) -> Dict:
        # (2, 4) sparse @ (4, n_samples) dense -> (2, n_samples), transposed back to (n_samples, 2)
        data['pred'] = torch.sparse.mm(self.weight.T, data['x'].T).T
        return data

# runner
# create some data
torch.random.manual_seed(42)
x = torch.arange(12).to(torch.float64).reshape(3, 4)
x1 = x[:, [0, -2, -1]]
x2 = x[:, [1, -2, -1]]
w1 = torch.tensor([[0.1], [0.2], [-0.1]]).to(torch.float64)
w2 = torch.tensor([[0.25], [0.3], [0.2]]).to(torch.float64)
b1 = torch.tensor([1.0]).to(torch.float64)
b2 = torch.tensor([-2.0]).to(torch.float64)
y1 = x1@w1 + b1 + torch.randn((3, 1))*0.01
y2 = x2@w2 + b2 + torch.randn((3, 1))*0.01
y = torch.cat((y1, y2), axis=1)

device = 'cpu'
slin = MySparseLinear(weights=torch.cat((w1[0], w2[0], w1[1], w2[1], w1[2], w2[2])))
slin.to(device)
print(slin.weight.to_dense())
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(slin.parameters(), lr=1e-2)
data = {'x': x.to(device)}

n_epochs = 2
for i in range(n_epochs):
    print(f'Epoch {i + 1} ---------')
    xdata = slin(data)
    loss = criterion(xdata['pred'], y.to(device))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f'weight.grad: {slin.weight.grad}')
    print(f'weight: {slin.weight}')
    print()
Output of MySparseLinear for 2 epochs:
Epoch 1 ---------
weight.grad: tensor(indices=tensor([[0, 2, 3, 1, 2, 3],
                       [0, 0, 0, 1, 1, 1]]),
       values=tensor([-24.0478, -36.0758, -42.0898,  60.1412,  72.1628,
                       84.1844]),
       size=(4, 2), nnz=6, dtype=torch.float64, layout=torch.sparse_coo)
weight: Parameter containing:
tensor(indices=tensor([[0, 1, 2, 2, 3, 1, 2, 3],
                       [0, 1, 0, 1, 0, 1, 1, 1]]),
       values=tensor([ 0.3405,  0.2500,  0.5608,  0.3000,  0.3209, -0.6014,
                      -0.7216, -0.6418]),
       size=(4, 2), nnz=8, dtype=torch.float64, layout=torch.sparse_coo,
       requires_grad=True)

Epoch 2 ---------
weight.grad: tensor(indices=tensor([[0, 2, 3, 1, 2, 3],
                       [0, 0, 0, 1, 1, 1]]),
       values=tensor([ 187.1148,  247.9598,  278.3822, -475.3034, -542.6602,
                      -610.0171]),
       size=(4, 2), nnz=6, dtype=torch.float64, layout=torch.sparse_coo)
weight: Parameter containing:
tensor(indices=tensor([[0, 1, 2, 2, 3, 1, 2, 3],
                       [0, 1, 0, 1, 0, 1, 1, 1]]),
       values=tensor([-1.5307,  0.2500, -1.9188,  0.3000, -2.4629,  4.1516,
                       4.7050,  5.4583]),
       size=(4, 2), nnz=8, dtype=torch.float64, layout=torch.sparse_coo,
       requires_grad=True)

In the meantime I have implemented the required functionality using masking (sketched below), but I would still like to understand what is going on with the behavior above.
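
For completeness, the masking workaround is roughly the following: a dense weight parameter multiplied by a fixed 0/1 mask in forward(), so the masked-out entries receive zero gradient and never move (a minimal sketch with made-up names, not my exact code):

from typing import Dict

import torch
import torch.nn as nn


class MaskedLinear(nn.Module):
    def __init__(self, mask: torch.Tensor) -> None:
        super().__init__()
        # Fixed 0/1 connectivity pattern, registered as a buffer so it follows .to(device).
        self.register_buffer('mask', mask.to(torch.float64))
        self.weight = nn.Parameter(torch.randn(mask.shape, dtype=torch.float64))

    def forward(self, data: Dict) -> Dict:
        # Masked-out positions contribute nothing to the output, so their gradient is zero.
        data['pred'] = data['x'] @ (self.weight * self.mask)
        return data

# Same connectivity as MySparseLinear above.
mask = torch.zeros(4, 2)
mask[[0, 1, 2, 2, 3, 3], [0, 1, 0, 1, 0, 1]] = 1.0
mlin = MaskedLinear(mask)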