How to add a sparse tensor to a dense parameter quickly

I am doing some benchmarking of EmbeddingBag vs. an nn.Linear layer with torch sparse tensors. My data is bag-of-words TF-IDF features, which are sparse (N x 16 million), where N is the number of documents. The two methods I am comparing are:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class EmbedLR(nn.Module):
    def __init__(self, vocab_size, alpha, learning_rate):
        super(EmbedLR, self).__init__()
        self.embed = nn.EmbeddingBag(
            num_embeddings=vocab_size,
            embedding_dim=1,
            mode="sum",
            sparse=True
        )

        self.learning_rate = learning_rate

        initial_eta0 = np.sqrt(1.0 / np.sqrt(alpha))
        self.optimal_init = 1.0 / (initial_eta0 * alpha)
        self.optimizer = optim.SGD(self.parameters(), self.learning_rate)

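    # converts a scipy CSR batch into EmbeddingBag inputs:
    # indices -> token ids, indptr[:-1] -> bag offsets, data -> per-sample (TF-IDF) weights
    def compute_offsets(self, batch):
        return (
            torch.from_numpy(batch.indices).long(),
            torch.from_numpy(batch.indptr[:-1]).long(),
            torch.from_numpy(batch.data).float(),
        )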

    def forward(self, x):
        input, offsets, per_sample_weights = self.compute_offsets(x)
        return self.embed(
            input=input,
            offsets=offsets,
            per_sample_weights=per_sample_weights
        )

    def zero_init(self):
        nn.init.zeros_(self.embed.weight)

class LogisticRegression(nn.Module):
    def __init__(self, vocab_size, alpha, learning_rate):
        super(LogisticRegression, self).__init__()

        self.vocab_size = vocab_size
        self.alpha = alpha
        self.learning_rate = learning_rate

        self.linear = nn.Linear(self.vocab_size, 1, bias=False)

        initial_eta0 = np.sqrt(1.0 / np.sqrt(self.alpha))
        self.optimal_init = 1.0 / (initial_eta0 * self.alpha)
        self.optimizer = optim.SGD(self.parameters(), self.learning_rate)

    # x is a sparse torch tensor
    def forward(self, x):
        return self.linear(x)

    def zero_init(self):
        nn.init.zeros_(self.linear.weight)
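
To sanity-check that the two models compute the same thing, here is a small sketch. The sizes are hypothetical and a random scipy CSR matrix stands in for the TF-IDF batch. It converts one CSR batch both into EmbeddingBag arguments (as compute_offsets does) and into a torch sparse COO tensor, then uses torch.sparse.mm as an explicit stand-in for the sparse-input linear layer. The helper names csr_to_embedding_bag_args and csr_to_torch_coo are made up for illustration.

```python
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn

torch.manual_seed(0)

# Tiny stand-in for the N x vocab_size TF-IDF matrix (hypothetical sizes).
vocab_size = 10
batch = sp.random(4, vocab_size, density=0.3, format="csr",
                  dtype=np.float32, random_state=0)


def csr_to_embedding_bag_args(csr):
    # Hypothetical helper: indices -> token ids, indptr[:-1] -> bag starts, data -> weights.
    return (
        torch.from_numpy(csr.indices).long(),
        torch.from_numpy(csr.indptr[:-1]).long(),
        torch.from_numpy(csr.data).float(),
    )


def csr_to_torch_coo(csr):
    # Hypothetical helper: build a 2-D sparse COO tensor with the same entries.
    coo = csr.tocoo()
    indices = torch.from_numpy(np.vstack([coo.row, coo.col])).long()
    values = torch.from_numpy(coo.data).float()
    return torch.sparse_coo_tensor(indices, values, size=coo.shape)


embed = nn.EmbeddingBag(vocab_size, 1, mode="sum", sparse=True)
linear = nn.Linear(vocab_size, 1, bias=False)
with torch.no_grad():
    # Share weights: EmbeddingBag stores (V, 1), Linear stores (1, V).
    linear.weight.copy_(embed.weight.t())

inp, offsets, weights = csr_to_embedding_bag_args(batch)
out_bag = embed(inp, offsets, per_sample_weights=weights)
out_lin = torch.sparse.mm(csr_to_torch_coo(batch), linear.weight.t())
print(torch.allclose(out_bag, out_lin, atol=1e-5))
```

Both paths compute, for each document, the TF-IDF-weighted sum of the weights of its tokens; they differ only in how the gradient with respect to the weight is represented.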

The output of running these models on the same data is as follows:

LogisticRegression:

Forward time: 0.6464931964874268
Loss value: 0.6931458711624146
Loss backward time: 0.671107292175293
p.add_(d_p, alpha=-group['lr']) takes 0.0021352767944335938
Optimizer step time: 0.0022284984588623047
Optimizer zero grad time: 0.0012955665588378906
Converge in 1.3219084739685059 seconds
Convergence per doc: 7.284446321532517e-05

EmbedLR:

Forward time: 0.029858112335205078
Loss value: 0.6931458711624146
Loss backward time: 0.0931253433227539
p.add_(d_p, alpha=-group['lr']) takes 3.497199296951294
Optimizer step time: 3.497309923171997
Optimizer zero grad time: 0.004025459289550781
Converge in 3.624892234802246 seconds
Convergence per doc: 0.00019975159722280522

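Using EmbeddingBag, the forward pass is much faster than using a linear layer on a sparse tensor, and the loss backward call is faster too (0.093 s vs. 0.671 s), but the backward-plus-update phase is slower overall. The main bottleneck is the optimizer step (about 3.5 s vs. 0.002 s). Specifically, this line inside the SGD optimizer is much faster for LogisticRegression than for EmbedLR:

p.add_(d_p, alpha=-group['lr'])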

Here p is the parameter and d_p is its gradient (the weight vector, either linear.weight or embed.weight). In EmbedLR, d_p is a sparse tensor while p is dense, and I believe this dense-sparse addition is why the step is slow. In LogisticRegression it is fast because both p and d_p are dense.
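A minimal sketch to reproduce this in isolation, assuming a recent PyTorch: with sparse=True, EmbeddingBag's weight.grad comes back as a sparse COO tensor, so the SGD update becomes dense.add_(sparse). It times that call against the same update with the gradient densified first. The vocabulary size (2**20 instead of 2**24) and batch shape are hypothetical, chosen to keep memory small. Timings depend on the machine and PyTorch version, so no numbers are claimed here.

```python
import time

import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size = 2**20  # hypothetical; the post's vocabulary is 2**24 = 16,777,216
nnz = 50_000        # hypothetical number of non-zero TF-IDF entries in one batch
lr = 0.1

embed = nn.EmbeddingBag(vocab_size, 1, mode="sum", sparse=True)
inp = torch.randint(0, vocab_size, (nnz,))
offsets = torch.arange(0, nnz, 100)  # 500 bags of 100 tokens each
weights = torch.rand(nnz)

embed(inp, offsets, per_sample_weights=weights).sum().backward()
d_p = embed.weight.grad
print(d_p.is_sparse, d_p.shape)  # sparse gradient, same shape as the dense weight

before = embed.weight.detach().clone()

# dense += sparse: the call SGD makes for EmbedLR
p_sparse = before.clone()
start = time.perf_counter()
p_sparse.add_(d_p, alpha=-lr)
t_sparse = time.perf_counter() - start

# dense += dense: the situation in LogisticRegression
dense_grad = d_p.to_dense()
p_dense = before.clone()
start = time.perf_counter()
p_dense.add_(dense_grad, alpha=-lr)
t_dense = time.perf_counter() - start

print(f"dense.add_(sparse): {t_sparse:.6f} s, dense.add_(dense): {t_dense:.6f} s")
```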

How can I speed up this optimizer-step bottleneck? Is there an efficient way to do dense-sparse addition? I tried making the parameter p itself sparse, like:

self.embed.weight = nn.Parameter(torch.sparse.FloatTensor(16777216, 1))

but got the following error:

RuntimeError: Could not run 'aten::_embedding_bag' with arguments from the 'SparseCPUTensorId' backend. 'aten::_embedding_bag' is only available for these backends: [CPUTensorId, CUDATensorId, VariableTensorId].
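Given this error, the EmbeddingBag weight has to stay dense; what can change is how the sparse gradient is applied. One option, sketched below as a hypothetical helper sparse_sgd_step (not a PyTorch API), is to skip p.add_(d_p) and scatter the update into only the rows the gradient touches with Tensor.index_add_. It assumes plain SGD (no momentum or weight decay) and a hybrid COO gradient with one sparse dimension, which is the layout EmbeddingBag produces (values of shape (nnz, embedding_dim)). Whether it actually beats p.add_ depends on the PyTorch version, so it is worth benchmarking against the numbers above.

```python
import torch
import torch.nn as nn


def sparse_sgd_step(param: nn.Parameter, lr: float) -> None:
    """Hypothetical helper: param -= lr * param.grad, touching only rows in a sparse grad."""
    grad = param.grad
    if grad is None:
        return
    with torch.no_grad():
        if grad.is_sparse:
            # Assumes a hybrid COO grad with one sparse dim (the EmbeddingBag layout):
            # indices() has shape (1, nnz) and values() has shape (nnz, embedding_dim).
            grad = grad.coalesce()  # merge duplicate rows; indices()/values() require this
            rows = grad.indices()[0]
            param.index_add_(0, rows, grad.values() * -lr)
        else:
            param.add_(grad, alpha=-lr)


# Tiny demo with a hypothetical vocabulary of 1000 tokens and two bags.
torch.manual_seed(0)
embed = nn.EmbeddingBag(1000, 1, mode="sum", sparse=True)
inp = torch.tensor([1, 5, 5, 7])              # token ids (row 5 appears twice)
offsets = torch.tensor([0, 2])                # bag 0 = [1, 5], bag 1 = [5, 7]
weights = torch.tensor([0.5, 1.0, 2.0, 3.0])  # per-sample (TF-IDF-like) weights

embed(inp, offsets, per_sample_weights=weights).sum().backward()
before = embed.weight.detach().clone()
sparse_sgd_step(embed.weight, lr=0.1)

expected = before - 0.1 * embed.weight.grad.to_dense()
print(torch.allclose(embed.weight.detach(), expected))  # True
```

In a training loop this would replace optimizer.step() for the embedding weight; the gradient still needs to be zeroed afterwards as usual.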