Error when calling optimizer step() with a sparse Embedding

Hello,

I’m using libtorch compiled from the GitHub sources (commit ea7bebb7fe5947927a04496ac489a22997c412bd).

I use an embedding layer initialized with sparse(true), and when I call the step() function of my optimizer, I get the following error:

terminate called after throwing an instance of 'c10::Error'
  what():  unsupported tensor layout: Sparse (validate at /home/franck/torch/aten/src/ATen/native/TensorIterator.h:127)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6a (0x7f021cd62eba in /home/franck/.local/lib/libc10.so)
frame #1: <unknown function> + 0x1376148 (0x7f021e0ef148 in /home/franck/.local/lib/libtorch_cpu.so)
frame #2: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&, bool) + 0xd3 (0x7f021e259e03 in /home/franck/.local/lib/libtorch_cpu.so)
frame #3: at::native::mul(at::Tensor const&, at::Tensor const&) + 0x5d (0x7f021dfe343d in /home/franck/.local/lib/libtorch_cpu.so)
frame #4: at::native::mul(at::Tensor const&, c10::Scalar) + 0x4c (0x7f021dfea5bc in /home/franck/.local/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x1756e4c (0x7f021e4cfe4c in /home/franck/.local/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x16ad686 (0x7f021e426686 in /home/franck/.local/lib/libtorch_cpu.so)
frame #7: at::Tensor c10::Dispatcher::callUnboxed<at::Tensor, at::Tensor const&, c10::Scalar>(c10::OperatorHandle const&, at::Tensor const&, c10::Scalar) const + 0x110 (0x7f021e0e1f60 in /home/franck/.local/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x3421b70 (0x7f022019ab70 in /home/franck/.local/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x16ad686 (0x7f021e426686 in /home/franck/.local/lib/libtorch_cpu.so)
frame #10: at::Tensor::mul(c10::Scalar) const + 0x16f (0x7f021e32c21f in /home/franck/.local/lib/libtorch_cpu.so)
frame #11: torch::optim::SGD::step() + 0x59e (0x7f022068a25e in /home/franck/.local/lib/libtorch_cpu.so)
frame #12: main + 0xb84 (0x557e69674e74 in /home/franck/new_macaon/build/dev/dev)
frame #13: __libc_start_main + 0xf3 (0x7f021c954153 in /usr/lib/libc.so.6)
frame #14: _start + 0x2e (0x557e69675ece in /home/franck/new_macaon/build/dev/dev)

Judging from the stack trace, the failure happens inside SGD::step() when it multiplies the sparse gradient by a scalar (Tensor::mul(Scalar)), which TensorIterator does not support.

Here is my full code (the data has been replaced by random tensors):

#include <torch/torch.h>

constexpr int batchSize = 50;
constexpr int nbExamples = 350000;
constexpr int embeddingSize = 20;
constexpr int nbClasses = 15;
constexpr int nbWordsPerDatapoint = 5;
constexpr int maxNbEmbeddings = 200000;

struct Network : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  torch::nn::Embedding wordEmbeddings{nullptr};
  Network()
  {
    linear = register_module("linear", torch::nn::Linear(embeddingSize, nbClasses));
    wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
  }
  torch::Tensor forward(torch::Tensor input)
  {
    // Each datapoint in the batch is a sentence given as word indices; the
    // sentence embedding is the mean of its word embeddings.
    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
    return torch::softmax(linear(embeddingsOfInput), 1);
  }
};

int main(int argc, char * argv[])
{
  auto nn = Network();
  torch::optim::SGD optimizer(nn.parameters(), torch::optim::SGDOptions(2e-4));
  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
  {
    optimizer.zero_grad();
    auto batch = torch::randint(maxNbEmbeddings, {batchSize, nbWordsPerDatapoint}, at::kLong);
    auto referenceClasses = torch::zeros({batchSize}, at::kLong);
    auto prediction = nn.forward(batch);
    auto loss = torch::nll_loss(torch::log(prediction), referenceClasses);
    loss.backward();
    optimizer.step();
  }
  return 0;
}

This code works when I omit the ‘.sparse(true)’ when creating the options for the Embedding layer. However, I really need sparse updates for my embedding layer: with dense gradients, the backward pass is approximately 150 times slower than the forward pass, which makes the embedding layer unusable.
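In the meantime I’m considering working around torch::optim::SGD with a hand-rolled update step. This is only a sketch, and it assumes that add_() on a dense parameter accepts a sparse gradient (as sparse embedding updates do in the Python API):

#include <torch/torch.h>

// Minimal manual SGD step that bypasses torch::optim::SGD.
// Assumption: add_() on a dense parameter accepts a sparse gradient,
// as sparse embedding updates do in the Python API.
// Tensors are reference handles, so taking the vector by value is cheap.
void manualSgdStep(std::vector<torch::Tensor> parameters, double learningRate)
{
  torch::NoGradGuard noGrad; // don't record the update in the autograd graph
  for (auto & parameter : parameters)
  {
    if (!parameter.grad().defined())
      continue;
    // parameter -= learningRate * gradient, for both dense and sparse gradients
    parameter.add_(parameter.grad(), -learningRate);
  }
}

In main() I would then call manualSgdStep(nn.parameters(), 2e-4) instead of optimizer.step(), and nn.zero_grad() instead of optimizer.zero_grad(). Still, I would much rather use the built-in optimizer if sparse gradients are supposed to be supported.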

I would be glad if someone could help me with this issue.
Thanks!