Sparse Embedding failing with Adam: torch.cuda.sparse.FloatTensor has no attribute addcmul_

From what I have read so far, it seems the option sparse=True is necessary when tuning the embedding matrix during training, since otherwise the backward step takes a very long time. (This was my experience: an average of ~7 s per backward pass with the non-sparse embedding vs. ~0.38 s with the sparse one.)

However I have encountered an issue when trying to apply optimization (with Adam):

torch.cuda.sparse.FloatTensor has no attribute addcmul_

I have tried to build a minimal model that reproduces the error; it is below:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # sparse=True makes the backward pass produce a sparse gradient for the weight
        self.embedding = nn.Embedding(10000, 300, sparse=True).cuda()
        self.linear1 = nn.Linear(300, 2).cuda()  # two outputs so CrossEntropyLoss accepts labels 0/1
        self.optimizer = optim.Adam(self.parameters(), 0.0001)

    def forward(self):
        ixs = Variable(torch.LongTensor([1, 2, 3, 4])).cuda()
        vecs = self.embedding(ixs)
        logits = F.sigmoid(self.linear1(vecs))
        return logits

    def optimize(self, logits):
        labels = Variable(torch.LongTensor([0, 1, 0, 1])).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
        loss = criterion(logits, labels)
        loss.backward()
        self.optimizer.step()  # fails here: the sparse gradient hits Adam's addcmul_


model = MyModel()
logits = model.forward()
model.optimize(logits)

Please let me know if I am doing something wrong! Thanks very much for your kind help.

I have the same problem!

Adam is not suited for sparse gradients, because Adam has a step that densifies things.
Try using RMSProp instead.

This thread has more context:
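For what it's worth, here is a minimal CPU sketch (my own, not from the linked thread) of the same setup trained with optim.SGD, which with its default hyperparameters also accepts sparse gradients; the shapes mirror the model in the first post:

import torch
import torch.nn as nn
import torch.optim as optim

embedding = nn.Embedding(10000, 300, sparse=True)
linear = nn.Linear(300, 2)
# Plain SGD (momentum and weight decay left at 0) handles both the sparse
# gradient from the embedding and the dense gradients from the linear layer.
optimizer = optim.SGD(list(embedding.parameters()) + list(linear.parameters()), lr=0.0001)

ixs = torch.LongTensor([1, 2, 3, 4])
labels = torch.LongTensor([0, 1, 0, 1])

logits = linear(embedding(ixs))
loss = nn.CrossEntropyLoss()(logits, labels)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # no addcmul_ error with this optimizer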


This sparse argument is not documented in the docstring of the class. What is its purpose exactly? Once we look up the input vectors from the embedding matrix during batch preparation, doesn’t the input just become a set of matrix entries? I would be happy if you could clarify.

Thanks.

@ozancaglayan it is for the gradients w.r.t. the embeddings: they can be treated as a sparse matrix (non-zero only for the embedding rows used in that particular forward pass).
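For illustration, a minimal sketch (my own example, using the current tensor API rather than Variable) showing that with sparse=True the gradient of the embedding weight comes back as a sparse tensor covering only the rows used in the forward pass:

import torch
import torch.nn as nn

emb = nn.Embedding(10000, 300, sparse=True)
# only rows 1, 2, 3, 4 are touched in this forward pass
emb(torch.LongTensor([1, 2, 3, 4])).sum().backward()

print(emb.weight.grad.is_sparse)             # True
print(emb.weight.grad.coalesce().indices())  # tensor([[1, 2, 3, 4]])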

How do we use this in practice? If the embedding layer is sparse but the rest of the network is dense, do we need two different optimizers? Is there any way to do this with param groups?

import time
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.nn.parameter import Parameter


class Toy(nn.Module):
    def __init__(self):
        super(Toy, self).__init__()
        self.embed = nn.Embedding(100000, 256, sparse=True)
        self.lin = nn.Linear(256, 256)

    def forward(self, idx):
        return self.lin(F.sigmoid(self.embed(idx)))


toy = Toy().cuda()
# SparseAdam gets *all* parameters here, including the Linear layer,
# whose gradients are dense -- that is what triggers the error below.
optimizer = optim.SparseAdam(toy.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().cuda()

x = Variable(torch.from_numpy(numpy.random.randint(0, 100000, 256)).cuda(),
             requires_grad=False)
t = Variable(torch.LongTensor(numpy.random.randint(0, 2, 256)).cuda(),
             requires_grad=False)

start_time = time.time()
for _ in range(2000):
    optimizer.zero_grad()
    y = toy(x)
    cost = criterion(y, t)
    cost.backward()
    optimizer.step()
    print(time.time() - start_time)

(fails with RuntimeError: SparseAdam does not support dense gradients, please consider Adam instead)


Use SparseAdam for the embedding parameters, and use another optimizer for the other params.
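For example, something along these lines (a sketch reusing Toy, toy, x, t, and criterion from the snippet above; the way the parameters are split here is my own illustration):

# SparseAdam only sees the parameter that gets sparse gradients...
sparse_params = [toy.embed.weight]
# ...and a regular Adam handles everything else (the Linear weight and bias).
dense_params = [p for name, p in toy.named_parameters() if name != 'embed.weight']

sparse_optimizer = optim.SparseAdam(sparse_params, lr=0.001)
dense_optimizer = optim.Adam(dense_params, lr=0.001)

for _ in range(2000):
    sparse_optimizer.zero_grad()
    dense_optimizer.zero_grad()
    cost = criterion(toy(x), t)
    cost.backward()
    sparse_optimizer.step()
    dense_optimizer.step()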


It sounds a little cumbersome for something that should be straightforward.

I have the feeling that in the TF version, https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer, this is handled entirely inside the lower-level LazyAdam implementation. Am I correct?

Yeah, it is a bit cumbersome. We had a discussion on this; IIRC, the conclusion was that for dense grads it would still be best to use the original Adam rather than SparseAdam (tf.LazyAdam). Hence we made the dense-gradient case a hard error. Hope that clarifies things.

Just curious: what if, when it iterates through the param groups, it used SparseAdam when the grads are sparse and dense Adam when they are not? Then just setting sparse=True on a layer would automatically switch that layer to SparseAdam.

As it is now, we have to go through and find all the param groups that we know are going to be sparse, and then pass them to a separate optimizer.

Right. However, I don’t think such a mechanism would be a good default. There are also other optimizers that work with sparse gradients (e.g. SGD, IIRC), and we expect to have more in the future, so restricting this to SparseAdam isn’t a good choice.

That said, I agree that this is a bit clunky at the moment. Perhaps some way to combine optimizers into one would be good.
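In the meantime, one way to hide the bookkeeping is a small wrapper (hypothetical, not part of torch.optim) that fans zero_grad() and step() out to several optimizers:

class MultiOptimizer:
    # Hypothetical convenience wrapper: holds several optimizers and
    # forwards zero_grad()/step() to all of them.
    def __init__(self, *optimizers):
        self.optimizers = optimizers

    def zero_grad(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()

# e.g. optimizer = MultiOptimizer(sparse_optimizer, dense_optimizer)
# so the training loop only has to deal with a single optimizer object.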
