When to densify sparse tensors?

For a dataset that uses sparse tensors, should I be calling tensor.to_dense() in Dataset.__getitem__() or in my collate_fn?

Background:

I’ve noticed that building datasets on top of sparse tensors (still a beta feature) leads to a lot of errors like this:

NotImplementedError: Could not run 'aten::remainder.Tensor' with arguments from the 'SparseCUDA' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::remainder.Tensor' is only available for these backends: [CPU, CUDA, HIP, MPS, IPU, XPU, HPU, VE, Meta, PrivateUse1, PrivateUse2, PrivateUse3, FPGA, ORT, Vulkan, Metal, QuantizedCPU, QuantizedCUDA, QuantizedHIP, QuantizedMPS, QuantizedIPU, QuantizedXPU, QuantizedHPU, QuantizedVE, QuantizedMeta, QuantizedPrivateUse1, QuantizedPrivateUse2, QuantizedPrivateUse3, CustomRNGKeyId, MkldnnCPU, SparseCsrCPU, SparseCsrCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMeta, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradNestedTensor, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PythonDispatcher].

I don’t think calling to_dense() in __getitem__ rather than in the collate function will get around the NotImplementedError.

When are you calling remainder? You might have to do torch.remainder(sparse_coo1.to_dense(), sparse_coo2.to_dense()), but I can’t tell without knowing how you are using it.
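A minimal sketch of that suggestion, assuming a small COO tensor on the CPU and a scalar divisor (the tensor `a` is hypothetical example data). The only point is that densifying first routes the call to the dense kernel, which does exist:

```python
import torch

# A small sparse COO tensor (hypothetical example data).
indices = torch.tensor([[0, 1, 2], [0, 1, 2]])
values = torch.tensor([7.0, 8.0, 9.0])
a = torch.sparse_coo_tensor(indices, values, (3, 3))

# aten::remainder has no sparse kernel, so densify before calling it.
result = torch.remainder(a.to_dense(), 3)

# Caveat: if the divisor is also sparse, its implicit zeros become explicit
# zeros after to_dense(), and x % 0 yields NaN for floating-point tensors.
```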

Hi @nivek, it turns out that I had a bug in my collate_fn. I have since fixed it, but now I am unable to set pin_memory=True in the DataLoader (I created a separate post specifically about this error here). I’d be very interested in any feedback you have on the training error described at the very end below!

Dataset __getitem__() and collate_fn (works without pin_memory)

Here is __getitem__():

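    def __getitem__(self, idx):
        # Samplers may pass a tensor index; convert it to a plain Python int/list.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {
            'rna': self.rna[idx],    # AnnData view of this cell; .X may be sparse
            'atac': self.atac[idx],  # AnnData view of this cell; .X may be sparse
            'cell_idx': idx,
            'batch_idx': self.mdata.obs.batch.iloc[idx],  # positional lookup, like the loop below
            # 'cat_obs':pd.to_numeric(self.mdata.obs[self.cat_obs_names].iloc[idx])
            }
        # One numeric entry per categorical obs column.
        for i in self.cat_obs_names:
            sample[i] = pd.to_numeric(self.mdata.obs[i].iloc[idx])
        return sample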

Here is the collate_fn I pass to the DataLoader:

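import anndata
import numpy as np
import torch
from scipy.sparse import vstack  # assumption: the sparse branch relies on SciPy's vstack

# sparse_csr_to_tensor is the poster's own helper (not shown in this post).

def sparse_batch_collate(batch: list, device: torch.device = 'cpu'):
    """
    Collate function: stack the per-cell AnnData slices into one batch,
    densify sparse (CSR view) data, and move every tensor to `device`.
    """
    # ATAC: one vstack per batch, then a single to_dense() call.
    if isinstance(batch[0]['atac'].X, anndata._core.views.SparseCSRView):
        atac_batch = sparse_csr_to_tensor(vstack([x['atac'].X for x in batch])).to_dense().to(device)
    else:
        # Dense .X: stack with NumPy (SciPy's vstack would return a sparse matrix).
        atac_batch = torch.as_tensor(np.vstack([x['atac'].X for x in batch]), dtype=torch.float32).to(device)

    # RNA: same logic as ATAC.
    if isinstance(batch[0]['rna'].X, anndata._core.views.SparseCSRView):
        rna_batch = sparse_csr_to_tensor(vstack([x['rna'].X for x in batch])).to_dense().to(device)
    else:
        rna_batch = torch.as_tensor(np.vstack([x['rna'].X for x in batch]), dtype=torch.float32).to(device)

    batch_idx = torch.tensor([[x['batch_idx']] for x in batch], dtype=torch.int16).to(device)
    # int64 rather than int16: int16 overflows once a dataset has more than 32,767 cells.
    cell_idx = torch.tensor([[x['cell_idx']] for x in batch], dtype=torch.int64).to(device)

    batch_collate = {
        'atac': atac_batch,
        'rna': rna_batch,
        'cell_idx': cell_idx,
        'batch_idx': batch_idx,
    }
    # Remaining keys are the categorical obs columns added in __getitem__.
    for key in batch[0].keys() - {'rna', 'atac', 'cell_idx', 'batch_idx'}:
        batch_collate[key] = torch.tensor([[x[key]] for x in batch], dtype=torch.int16).to(device)
    return batch_collate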
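The post doesn't show sparse_csr_to_tensor. For readers reproducing this, here is a hypothetical version, assuming it receives a SciPy CSR matrix and should return a torch sparse CSR tensor; the poster's real helper may differ:

```python
import numpy as np
import scipy.sparse as sp
import torch

def sparse_csr_to_tensor(csr) -> torch.Tensor:
    """Hypothetical helper: SciPy CSR matrix -> torch sparse CSR tensor (float32)."""
    csr = sp.csr_matrix(csr)  # normalize views / other sparse formats to CSR
    return torch.sparse_csr_tensor(
        torch.from_numpy(csr.indptr.astype(np.int64)),   # row pointers
        torch.from_numpy(csr.indices.astype(np.int64)),  # column indices
        torch.from_numpy(csr.data.astype(np.float32)),   # nonzero values
        size=csr.shape,
    )

# Usage: stack two single-row CSR slices, convert once, densify once.
X = sp.csr_matrix(np.array([[0.0, 1.0, 0.0], [2.0, 0.0, 3.0]], dtype=np.float32))
dense_batch = sparse_csr_to_tensor(sp.vstack([X[0], X[1]])).to_dense()
```

Casting both index arrays to int64 keeps them the same dtype, which torch requires for CSR indices.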

Working: pin_memory=False

The following DataLoader works and my model trains:

loader = DataLoader(
    mds,
    batch_size=batch_size,
    collate_fn = lambda b: td.sparse_batch_collate(b, device=device),
)

Not Working: pin_memory=True

However, setting pin_memory=True in the above DataLoader generates the following error during training:

Epoch 1 of 10
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[11], line 20
     18 for epoch in range(epochs):
     19     print(f"Epoch {epoch+1} of {epochs}")
---> 20     train_epoch_loss = fit(model, loader)
     21     train_loss.append(train_epoch_loss)
     22     print(f"Train Loss: {train_epoch_loss:.4f}")

Cell In[11], line 4, in fit(model, dataloader)
      1 def fit(model, dataloader):
      2     # model.train()
      3     running_loss = 0.0
----> 4     for batch in dataloader:
      5         optimizer.zero_grad()
      6         # reconstruction, mu, logvar = model(batch[0],batch[1])
      7         # reconstruction = model(batch[0],batch[1])

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:628, in _BaseDataLoaderIter.__next__(self)
    625 if self._sampler_iter is None:
    626     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    627     self._reset()  # type: ignore[call-arg]
--> 628 data = self._next_data()
    629 self._num_yielded += 1
    630 if self._dataset_kind == _DatasetKind.Iterable and \
    631         self._IterableDataset_len_called is not None and \
    632         self._num_yielded > self._IterableDataset_len_called:

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
    671 data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    672 if self._pin_memory:
--> 673     data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
    674 return data

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:58, in pin_memory(data, device)
     56 elif isinstance(data, collections.abc.Mapping):
     57     try:
---> 58         return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
     59     except TypeError:
     60         # The mapping type may not support `__init__(iterable)`.
     61         return {k: pin_memory(sample, device) for k, sample in data.items()}

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:58, in <dictcomp>(.0)
     56 elif isinstance(data, collections.abc.Mapping):
     57     try:
---> 58         return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
     59     except TypeError:
     60         # The mapping type may not support `__init__(iterable)`.
     61         return {k: pin_memory(sample, device) for k, sample in data.items()}

File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:53, in pin_memory(data, device)
     51 def pin_memory(data, device=None):
     52     if isinstance(data, torch.Tensor):
---> 53         return data.pin_memory(device)
     54     elif isinstance(data, string_classes):
     55         return data

RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned

It seems like your collate_fn may be moving data to device=cuda. Once you move it there, you can’t pin it anymore, as the error indicates:
RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned

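Only dense tensors in CPU memory can be pinned, so the collate_fn should leave its outputs on the CPU and the move to the GPU should happen in the training loop.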
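A minimal sketch of that pattern, using a hypothetical toy dataset of SciPy CSR rows in place of the real MuData object (ToySparseDataset and cpu_collate are made-up names). The collate_fn densifies on the CPU and never calls .to(device), so the DataLoader can pin the batch, and the copy to the GPU happens in the loop:

```python
import numpy as np
import scipy.sparse as sp
import torch
from torch.utils.data import DataLoader, Dataset

class ToySparseDataset(Dataset):
    """Hypothetical stand-in for the real dataset: one sparse CSR row per cell."""
    def __init__(self, n_cells=64, n_features=100, density=0.05):
        rng = np.random.default_rng(0)
        dense = rng.random((n_cells, n_features), dtype=np.float32)
        dense[dense > density] = 0.0  # keep roughly `density` of the entries
        self.X = sp.csr_matrix(dense)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # Keep the row sparse here; densify once per batch in the collate_fn.
        return {"x": self.X[idx], "cell_idx": idx}

def cpu_collate(batch):
    """Stack sparse rows and densify on the CPU; no .to(device), so pinning works."""
    x = sp.vstack([item["x"] for item in batch]).toarray()
    return {
        "x": torch.from_numpy(x),
        "cell_idx": torch.tensor([item["cell_idx"] for item in batch], dtype=torch.int64),
    }

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = DataLoader(
    ToySparseDataset(),
    batch_size=16,
    collate_fn=cpu_collate,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host-to-GPU copies
)

for batch in loader:
    # Copy in the main process; non_blocking=True lets a copy from pinned
    # memory overlap with other work.
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    # ... forward / backward pass would go here ...
```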

Thanks @nivek, can you comment on the relative efficiency of the following two strategies?

  1. densify the sparse CPU tensors during collate (pin_memory=True) → move them to the GPU during the training loop
  2. move the sparse CPU tensors to the GPU during collate (pin_memory=False) → densify as needed
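A sketch of a variant of strategy 2 that keeps the GPU copy out of the collate_fn: the batch travels through the DataLoader as a sparse COO tensor on the CPU and is moved and densified in the main process. csr_batch_to_torch and the four-row batch are hypothetical:

```python
import numpy as np
import scipy.sparse as sp
import torch

def csr_batch_to_torch(rows):
    """Hypothetical helper: stack SciPy CSR rows into one CPU sparse COO tensor."""
    coo = sp.vstack(rows).tocoo()
    indices = torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64))
    values = torch.from_numpy(coo.data.astype(np.float32))
    return torch.sparse_coo_tensor(indices, values, coo.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical batch: four single-cell rows with one nonzero each.
X_dense = np.zeros((4, 6), dtype=np.float32)
X_dense[[0, 1, 2, 3], [1, 3, 0, 5]] = [1.0, 2.0, 3.0, 4.0]
X = sp.csr_matrix(X_dense)
rows = [X[i] for i in range(4)]

sparse_cpu = csr_batch_to_torch(rows)  # what the collate_fn would return
sparse_gpu = sparse_cpu.to(device)     # only indices and nonzero values are copied
dense_gpu = sparse_gpu.to_dense()      # densify on the device, right before use
```

On bytes moved, ignoring per-copy overhead: a 2-D COO tensor with float32 values and int64 indices costs 4 + 2 × 8 = 20 bytes per nonzero, while a dense float32 batch costs 4 bytes per entry. The sparse copy therefore moves fewer bytes when density is below about 20%. Per the error message above, though, sparse tensors could not be pinned in the PyTorch version used here, so that copy runs from pageable memory.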

Unfortunately I do not know the exact performance numbers, but I do know that it is not recommended to return CUDA tensors from DataLoader worker processes (multi-process loading, num_workers > 0) because of the many subtleties of using CUDA with multiprocessing. You can read more about it here.
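Because the answer depends on batch size, feature count, and sparsity, one option is to time both transfers on your own hardware. The harness below is a rough sketch built on synthetic data (compare_strategies and time_transfer are hypothetical names). It measures only the main-process cost per batch: strategy 1 times a pinned dense copy, and strategy 2 times a sparse copy followed by to_dense() on the device. The CPU-side densification that strategy 1 does in the collate_fn is deliberately excluded.

```python
import time
import torch

def time_transfer(fn, repeats=20):
    """Average wall-clock seconds per call of fn(), synchronizing CUDA if present."""
    fn()  # warm-up (CUDA context, caching allocator)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for any non_blocking copies to finish
    return (time.perf_counter() - start) / repeats

def compare_strategies(n_cells=256, n_features=20_000, density=0.02, seed=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gen = torch.Generator().manual_seed(seed)
    dense_cpu = torch.rand(n_cells, n_features, generator=gen)
    dense_cpu[dense_cpu > density] = 0.0  # roughly `density` nonzeros
    sparse_cpu = dense_cpu.to_sparse()    # COO copy of the same batch
    if device.type == "cuda":
        dense_cpu = dense_cpu.pin_memory()  # strategy 1 pins the dense batch

    def strategy_1():  # already dense on the CPU: copy every entry
        return dense_cpu.to(device, non_blocking=True)

    def strategy_2():  # copy only the nonzeros, then densify on the device
        return sparse_cpu.to(device).to_dense()

    return {
        "dense_then_copy": time_transfer(strategy_1),
        "copy_then_densify": time_transfer(strategy_2),
    }
```

On a CPU-only machine both "copies" are no-ops, so the numbers only mean something when CUDA is available.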