Optimize a transformed Embedding with sparse gradients

Hi,

I am trying to use SparseAdam, so I am storing all my parameters in embedding modules. So far so good, but before I access some of the data with a LongTensor of indices, I need to apply a transformation to my Embedding weights. As a concrete, isolated example, I need to do something like this:

```python
import torch
from torch import nn

a = nn.Embedding(num_embeddings=1000, embedding_dim=3, sparse=True)
b = nn.Embedding(num_embeddings=1000, embedding_dim=3, sparse=True)

gt = torch.ones((3, 3))

# SparseAdam only accepts parameters that receive sparse gradients
optimizer_xyz_diffuse = torch.optim.SparseAdam([a.weight, b.weight], lr=0.0001)

idxa = torch.tensor([1, 2, 3])
idxb = torch.tensor([4, 5, 6])
all_idxs = torch.arange(1000)  # every row of `a`

for step in range(100):
    # The transformation needs the whole table of `a`, so every row is looked up...
    a2 = 4 * a(all_idxs)
    # ...even though only a few transformed rows enter the loss.
    loss = (a2[idxa] + b(idxb) - gt).mean()

    optimizer_xyz_diffuse.zero_grad()
    loss.backward()
    optimizer_xyz_diffuse.step()
```

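As you can probably guess from the example, the gradient of `a` covers all 1000 rows, because every row is looked up before the transformation, even though only three of them end up in the loss. Is there a proper way to do this? I tried accessing the weights directly instead, but that immediately turns my gradient tensor into a dense one.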
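A small sketch that reproduces both observations with a single embedding (the names `emb` and `idx` are only for illustration): looking up every row yields a sparse gradient that still stores all 1000 rows, only three of which are nonzero, while transforming `emb.weight` directly produces a dense gradient.

```python
import torch
from torch import nn

emb = nn.Embedding(num_embeddings=1000, embedding_dim=3, sparse=True)
idx = torch.tensor([1, 2, 3])

# Case 1: look up every row, transform, then keep only three rows.
out = 4 * emb(torch.arange(1000))
out[idx].sum().backward()

sparse_grad = emb.weight.grad.coalesce()
stored_rows = sparse_grad.indices().shape[1]  # rows stored in the sparse gradient
nonzero_rows = int((sparse_grad.to_dense() != 0).any(dim=1).sum())  # rows with a nonzero gradient
print(sparse_grad.is_sparse, stored_rows, nonzero_rows)  # True 1000 3

# Case 2: transform the weight matrix directly instead of going through the lookup.
emb.weight.grad = None
(4 * emb.weight)[idx].sum().backward()
dense_grad_is_sparse = emb.weight.grad.is_sparse
print(dense_grad_is_sparse)  # False, so SparseAdam.step() would reject this gradient
```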

In the real problem, the transformation is more complicated than a scalar multiplication (it is an MLP), and it influences which indices are used, so it cannot be applied after filtering.
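One possible workaround, sketched under two assumptions that may not hold for the real model: the MLP acts on each embedding row independently, and the index selection itself is not differentiable. Under those assumptions the whole table can be transformed under `torch.no_grad()` only to choose the indices, and the MLP re-run with autograd on the selected rows, so the sparse gradient of `a` contains just those rows. `mlp` and `select_indices` (a top-k-by-norm rule) are hypothetical stand-ins for the real network and selection logic.

```python
import torch
from torch import nn

torch.manual_seed(0)

a = nn.Embedding(num_embeddings=1000, embedding_dim=3, sparse=True)
b = nn.Embedding(num_embeddings=1000, embedding_dim=3, sparse=True)
# Hypothetical stand-in for the real MLP; assumed to act on each row independently.
mlp = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 3))

gt = torch.ones((3, 3))
idxb = torch.tensor([4, 5, 6])

sparse_opt = torch.optim.SparseAdam([a.weight, b.weight], lr=1e-4)
dense_opt = torch.optim.Adam(mlp.parameters(), lr=1e-4)  # MLP weights get dense grads

def select_indices(transformed, k=3):
    # Hypothetical selection rule: the k rows with the largest norm.
    return transformed.norm(dim=1).topk(k).indices

for step in range(10):
    # 1) Transform the whole table without building a graph, only to pick indices.
    with torch.no_grad():
        idxa = select_indices(mlp(a.weight))
    # 2) Re-run the MLP on the selected rows only, with autograd enabled.
    a2 = mlp(a(idxa))
    loss = (a2 + b(idxb) - gt).mean()

    sparse_opt.zero_grad()
    dense_opt.zero_grad()
    loss.backward()
    sparse_opt.step()
    dense_opt.step()
```

This is not exactly the same training as the original loop: SparseAdam only updates the moment estimates of rows present in the gradient, so rows outside the current selection are no longer touched on every step.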

All the best,
Georgios Kopanas