SparseAdam raises a dimension error when the input features are empty

🐛 Describe the bug

from torch import nn, optim
import torch
import torch.nn.functional as F
from torch.autograd import Variable

def cumsum_1d_with_zero(input: torch.Tensor, include_last: bool = False) -> torch.Tensor:
    output = torch.zeros(input.size(0) + 1, dtype=input.dtype)
    torch.cumsum(input, 0, out=output[1:])
    if include_last:
        return output
    else:
        return output[:-1]
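`cumsum_1d_with_zero` turns per-sample lengths into EmbeddingBag offsets, which amounts to an exclusive prefix sum. The plain-Python sketch below mirrors it. The name `offsets_from_lengths` is made up for illustration and is not part of PyTorch. It also shows what the failing batch looks like: with four empty samples, every offset is 0 and the concatenated index list is empty, so every bag is empty.

```python
from itertools import accumulate


def offsets_from_lengths(lengths, include_last=False):
    """Hypothetical helper: start position of each bag in the flattened index list.

    Mirrors cumsum_1d_with_zero: prepend 0 to the running sum of lengths and,
    unless include_last is set, drop the final total.
    """
    out = [0] + list(accumulate(lengths))
    return out if include_last else out[:-1]


print(offsets_from_lengths([2, 0, 3]))     # [0, 2, 2]
print(offsets_from_lengths([0, 0, 0, 0]))  # [0, 0, 0, 0], the all-empty batch from this report
```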

class MyModel(nn.Module):

    def __init__(self):
        super(MyModel, self).__init__()
        self.embedding = nn.EmbeddingBag(4, 300, sparse=True, mode='sum')
        self.linear1 = nn.Linear(300, 1)
        self.optimizer = optim.SparseAdam(self.embedding.parameters(), 0.0001)
        self.dense_optimizer = optim.Adam(self.linear1.parameters(), 0.0001)

    def forward(self,inputs):
        rows = [torch.tensor(input, dtype=torch.int64) for input in inputs]
        lengths = torch.tensor([x.size(0) for x in rows], dtype=torch.int64)
        offsets = cumsum_1d_with_zero(lengths)
        indices = torch.cat(rows)
        embedding_res = self.embedding(indices, offsets)
        print(embedding_res)
        # Return raw logits: BCEWithLogitsLoss in optimize() applies the sigmoid itself.
        logits = self.linear1(embedding_res)
        return logits

    def optimize(self, logits):
        labels = torch.tensor([[0.0], [1.0], [0.0], [1.0]])  # shape (4, 1), float targets
        criterion = nn.BCEWithLogitsLoss()
        loss = criterion(logits, labels)
        print(loss.item())
        loss.backward()
        self.optimizer.step()
        self.dense_optimizer.step()

Running the model on a batch in which every sample has no features triggers the bug in SparseAdam:

model = MyModel()
inputs = [[], [], [], []]  # four samples, all with empty feature lists
result = model(inputs)
model.optimize(result)
Thanks! SparseAdam fails on this input, while the SGD optimizer handles it without error.

[Screenshot 2022-11-21, 10:26:12 PM]
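Until this is fixed, one possible workaround is sketched below. It assumes the crash comes from the sparse gradient that an all-empty batch produces, which has no stored entries. The helper `drop_empty_sparse_grads` is hypothetical, not a PyTorch API. It clears such gradients before `step()`, relying on the fact that SparseAdam skips parameters whose `.grad` is `None`. One side effect: SparseAdam's per-parameter step counter is not advanced for that batch.

```python
import torch
from torch import nn, optim


def drop_empty_sparse_grads(optimizer: optim.Optimizer) -> None:
    """Hypothetical workaround helper: clear sparse gradients that store no entries.

    SparseAdam skips any parameter whose .grad is None, so clearing an empty
    sparse gradient keeps step() from processing it.
    """
    for group in optimizer.param_groups:
        for p in group["params"]:
            grad = p.grad
            # _nnz() is the number of stored (index, value) pairs of a sparse COO tensor
            if grad is not None and grad.is_sparse and grad._nnz() == 0:
                p.grad = None


# Usage on a batch of four empty bags, mirroring the repro above
embedding = nn.EmbeddingBag(4, 300, sparse=True, mode="sum")
sparse_opt = optim.SparseAdam(embedding.parameters(), lr=1e-4)

indices = torch.empty(0, dtype=torch.int64)   # no features at all
offsets = torch.zeros(4, dtype=torch.int64)   # four bags, each starting at 0
embedding(indices, offsets).sum().backward()  # sparse grad with no stored entries

drop_empty_sparse_grads(sparse_opt)
sparse_opt.step()  # the embedding weight is skipped because its grad is now None
```

Non-empty batches are unaffected: their sparse gradients have at least one stored entry and are left in place for the optimizer.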

Versions

Collecting environment information…
PyTorch version: 1.8.2+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jun 22 2022, 20:18:18) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1022-aws-x86_64-with-glibc2.29
Is CUDA available: False
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.0
[pip3] torch==1.8.2+cu111
[conda] Could not collect