Training error when `pin_memory=True` and `collate_fn` builds the batch from sparse tensors

I am unable to set `pin_memory=True` in the `DataLoader`. It seems the workaround is to move the data to the GPU within my training/validation loops, but doesn't that defeat the purpose of `pin_memory`?

Dataset `__getitem__()` and `collate_fn` (works without `pin_memory`)

Here is `__getitem__()`:

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {
            'rna': self.rna[idx],
            'atac': self.atac[idx],
            'cell_idx': idx,
            'batch_idx': self.mdata.obs.batch[idx],
            }
        # Encode each categorical obs column as a numeric value.
        for name in self.cat_obs_names:
            sample[name] = pd.to_numeric(self.mdata.obs[name].iloc[idx])
        return sample

Here is the `collate_fn` I pass to the `DataLoader`:

    def sparse_batch_collate(batch: list, device: torch.device = 'cpu'):
        """
        Collate function that stacks AnnData CSR views into (dense) PyTorch tensors.
        Relies on torch, anndata, a vstack that can stack the .X blocks
        (e.g. scipy.sparse.vstack), and a user-defined sparse_csr_to_tensor helper.
        """
        if isinstance(batch[0]['atac'].X, anndata._core.views.SparseCSRView):
            atac_batch = sparse_csr_to_tensor(vstack([x['atac'].X for x in batch])).to_dense().to(device)
        else:
            atac_batch = torch.FloatTensor(vstack([x['atac'].X for x in batch])).to(device)

        if isinstance(batch[0]['rna'].X, anndata._core.views.SparseCSRView):
            rna_batch = sparse_csr_to_tensor(vstack([x['rna'].X for x in batch])).to_dense().to(device)
        else:
            rna_batch = torch.FloatTensor(vstack([x['rna'].X for x in batch])).to(device)

        batch_idx = torch.tensor([[x['batch_idx']] for x in batch], dtype=torch.int16).to(device)
        cell_idx = torch.tensor([[x['cell_idx']] for x in batch], dtype=torch.int16).to(device)

        batch_collate = {
            'atac': atac_batch,
            'rna': rna_batch,
            'cell_idx': cell_idx,
            'batch_idx': batch_idx,
        }
        # Any remaining categorical obs columns become int16 tensors as well.
        for key in batch[0].keys() - {'rna', 'atac', 'cell_idx', 'batch_idx'}:
            batch_collate[key] = torch.tensor([[x[key]] for x in batch], dtype=torch.int16).to(device)
        return batch_collate
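
For context, `sparse_csr_to_tensor` (not shown above) converts the stacked SciPy CSR matrix into a torch sparse tensor. Roughly, it does something like the sketch below, although the exact helper isn't important for the question:

    import torch
    from scipy.sparse import csr_matrix

    def sparse_csr_to_tensor(m: csr_matrix) -> torch.Tensor:
        # Illustrative sketch, not necessarily the implementation used above:
        # build a torch sparse CSR tensor from the SciPy CSR matrix's buffers.
        return torch.sparse_csr_tensor(
            torch.from_numpy(m.indptr).to(torch.int64),   # row pointers
            torch.from_numpy(m.indices).to(torch.int64),  # column indices
            torch.from_numpy(m.data).to(torch.float32),   # non-zero values
            size=m.shape,
        )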

Working: `pin_memory=False`

The following DataLoader works and my model trains:

    loader = DataLoader(
        mds,
        batch_size=batch_size,
        collate_fn=lambda b: td.sparse_batch_collate(b, device=device),
    )

Not Working: `pin_memory=True`

However, setting `pin_memory=True` in the above `DataLoader` generates the following error during training:

    Epoch 1 of 10
    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    Cell In[11], line 20
         18 for epoch in range(epochs):
         19     print(f"Epoch {epoch+1} of {epochs}")
    ---> 20     train_epoch_loss = fit(model, loader)
         21     train_loss.append(train_epoch_loss)
         22     print(f"Train Loss: {train_epoch_loss:.4f}")

    Cell In[11], line 4, in fit(model, dataloader)
          1 def fit(model, dataloader):
          2     # model.train()
          3     running_loss = 0.0
    ----> 4     for batch in dataloader:
          5         optimizer.zero_grad()
          6         # reconstruction, mu, logvar = model(batch[0],batch[1])
          7         # reconstruction = model(batch[0],batch[1])

    File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:628, in _BaseDataLoaderIter.__next__(self)
        625 if self._sampler_iter is None:
        626     # TODO(https://github.com/pytorch/pytorch/issues/76750)
        627     self._reset()  # type: ignore[call-arg]
    --> 628 data = self._next_data()
        629 self._num_yielded += 1
        630 if self._dataset_kind == _DatasetKind.Iterable and \
        631         self._IterableDataset_len_called is not None and \
        632         self._num_yielded > self._IterableDataset_len_called:

    File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
        671 data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        672 if self._pin_memory:
    --> 673     data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        674 return data

    File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:58, in pin_memory(data, device)
         56 elif isinstance(data, collections.abc.Mapping):
         57     try:
    ---> 58         return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
         59     except TypeError:
         60         # The mapping type may not support `__init__(iterable)`.
         61         return {k: pin_memory(sample, device) for k, sample in data.items()}

    File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:58, in <dictcomp>(.0)
         56 elif isinstance(data, collections.abc.Mapping):
         57     try:
    ---> 58         return type(data)({k: pin_memory(sample, device) for k, sample in data.items()})  # type: ignore[call-arg]
         59     except TypeError:
         60         # The mapping type may not support `__init__(iterable)`.
         61         return {k: pin_memory(sample, device) for k, sample in data.items()}

    File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:53, in pin_memory(data, device)
         51 def pin_memory(data, device=None):
         52     if isinstance(data, torch.Tensor):
    ---> 53         return data.pin_memory(device)
         54     elif isinstance(data, string_classes):
         55         return data

    RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned

I answered this in the other thread, but here it is again:

It looks like your `collate_fn` is moving the data to `device=cuda` (every tensor in the batch gets a `.to(device)` call). Once a tensor lives on the GPU, it can no longer be pinned, which is exactly what the error says:

    RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned

Only dense tensors that live on the CPU can be pinned, so the batch has to stay on the CPU until after the DataLoader has pinned it.
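
Moving the batch to the GPU inside the training loop does not defeat the purpose of `pin_memory`: pinning places the CPU-side batch in page-locked memory, which speeds up the host-to-device copy and lets it run asynchronously when you pass `non_blocking=True` to `.to()`. Below is a minimal sketch of that pattern, reusing the names from your snippets (`mds`, `batch_size`, `td.sparse_batch_collate`, `model`, `optimizer`) and assuming the collate function is called with `device='cpu'` so it returns CPU tensors:

    import torch
    from torch.utils.data import DataLoader

    device = torch.device('cuda')

    # Build each batch on the CPU; the DataLoader pins it after collation.
    loader = DataLoader(
        mds,
        batch_size=batch_size,
        collate_fn=lambda b: td.sparse_batch_collate(b, device='cpu'),
        pin_memory=True,
    )

    def fit(model, dataloader):
        running_loss = 0.0
        for batch in dataloader:
            # Copy the pinned CPU tensors to the GPU; non_blocking=True lets the
            # transfer overlap with other host work instead of blocking on it.
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
            optimizer.zero_grad()
            # ... forward pass, loss, backward, optimizer.step() as before ...
        return running_loss

Since your collate function already calls `.to_dense()` on the sparse blocks, the batch satisfies the "dense CPU tensors" requirement from the error message; the only change is that the final `.to(device)` moves out of the `collate_fn` and into the loop.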