Hi @nivek, it turns out that I had a bug in my collate_fn. I have since fixed that, but I am currently unable to set pin_memory=True
in the DataLoader (I created a separate post specifically about this error here). I’d be very interested in any feedback you have on the training error described (at the very end) below!
Dataset.__getitem__() and collate_fn (works without pin_memory)
Here is __getitem__():
def __getitem__(self, idx):
    """Return the sample at position ``idx`` as a plain dict.

    Parameters
    ----------
    idx : int, list of int, or torch.Tensor
        Positional index of the requested cell(s). A tensor is converted
        with ``Tensor.tolist()`` before use.

    Returns
    -------
    dict
        ``'rna'`` and ``'atac'`` hold ``self.rna[idx]`` and
        ``self.atac[idx]``; ``'cell_idx'`` is ``idx`` itself;
        ``'batch_idx'`` is the ``batch`` column of ``self.mdata.obs`` at
        position ``idx``; and there is one further entry per name in
        ``self.cat_obs_names`` holding ``pd.to_numeric`` of that ``obs``
        column at position ``idx``.
    """
    if torch.is_tensor(idx):
        idx = idx.tolist()

    obs = self.mdata.obs
    sample = {
        'rna': self.rna[idx],
        'atac': self.atac[idx],
        'cell_idx': idx,
        # Positional lookup, like the categorical columns below. Plain
        # ``series[idx]`` is label-based, so it would pick a different row
        # (or raise) whenever the obs index is not the default 0..n-1 range.
        'batch_idx': obs['batch'].iloc[idx],
    }
    for name in self.cat_obs_names:
        sample[name] = pd.to_numeric(obs[name].iloc[idx])
    return sample
Here is the collate_fn I pass to the DataLoader:
def _stack_modality(batch, key, device):
    """Stack the ``.X`` matrices of ``batch[i][key]`` into one dense tensor on ``device``."""
    stacked = vstack([sample[key].X for sample in batch])
    # The first sample decides the path for the whole batch.
    if isinstance(batch[0][key].X, anndata._core.views.SparseCSRView):
        return sparse_csr_to_tensor(stacked).to_dense().to(device)
    return torch.FloatTensor(stacked).to(device)


def _column_tensor(batch, key, dtype, device):
    """Build a ``(len(batch), 1)`` tensor from the scalar ``batch[i][key]`` values."""
    return torch.tensor([[sample[key]] for sample in batch], dtype=dtype).to(device)


def sparse_batch_collate(batch:list,device:torch.device='cpu'):
    """
    Collate a list of per-cell sample dicts into one dict of dense tensors.

    Parameters
    ----------
    batch : list of dict
        Samples with at least the keys ``'rna'``, ``'atac'``, ``'cell_idx'``
        and ``'batch_idx'``. ``'rna'`` and ``'atac'`` must expose a ``.X``
        matrix; any additional keys must hold scalar values.
    device : torch.device or str
        Device every returned tensor is moved to. Leave this at ``'cpu'``
        when the DataLoader uses ``pin_memory=True``: pinning is applied to
        the output of this function and only dense CPU tensors can be
        pinned. Move the batch to the GPU in the training loop instead.

    Returns
    -------
    dict
        ``'atac'`` and ``'rna'`` are the row-stacked, dense ``.X`` matrices.
        ``'cell_idx'`` is an int64 tensor and ``'batch_idx'`` plus every
        additional key are int16 tensors, each of shape ``(len(batch), 1)``.
    """
    batch_collate = {
        'atac': _stack_modality(batch, 'atac', device),
        'rna': _stack_modality(batch, 'rna', device),
        # int64, not int16: cell positions exceed 32767 in any dataset with
        # more than 2**15 cells, which int16 cannot represent.
        'cell_idx': _column_tensor(batch, 'cell_idx', torch.int64, device),
        'batch_idx': _column_tensor(batch, 'batch_idx', torch.int16, device),
    }
    # Walk the first sample's keys in order so the output layout is
    # deterministic (a set difference iterates in arbitrary order).
    for key in batch[0]:
        if key not in batch_collate:
            batch_collate[key] = _column_tensor(batch, key, torch.int16, device)
    return batch_collate
Working: pin_memory=False
The following DataLoader works and my model trains:
# Batches are moved to `device` inside the collate step, so the loader yields
# tensors that already live on that device. Keep pin_memory off with this
# setup: pinning is applied to the collate output, and only dense CPU tensors
# can be pinned.
def _collate_on_device(samples):
    """Collate ``samples`` with ``td.sparse_batch_collate`` onto ``device``."""
    return td.sparse_batch_collate(samples, device=device)


loader = DataLoader(
    dataset=mds,
    batch_size=batch_size,
    collate_fn=_collate_on_device,
)
Not Working: pin_memory=True
However, setting pin_memory=True in the above DataLoader generates the following error during training:
Epoch 1 of 10
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[11], line 20
18 for epoch in range(epochs):
19 print(f"Epoch {epoch+1} of {epochs}")
---> 20 train_epoch_loss = fit(model, loader)
21 train_loss.append(train_epoch_loss)
22 print(f"Train Loss: {train_epoch_loss:.4f}")
Cell In[11], line 4, in fit(model, dataloader)
1 def fit(model, dataloader):
2 # model.train()
3 running_loss = 0.0
----> 4 for batch in dataloader:
5 optimizer.zero_grad()
6 # reconstruction, mu, logvar = model(batch[0],batch[1])
7 # reconstruction = model(batch[0],batch[1])
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:628, in _BaseDataLoaderIter.__next__(self)
625 if self._sampler_iter is None:
626 # TODO(https://github.com/pytorch/pytorch/issues/76750)
627 self._reset() # type: ignore[call-arg]
--> 628 data = self._next_data()
629 self._num_yielded += 1
630 if self._dataset_kind == _DatasetKind.Iterable and \
631 self._IterableDataset_len_called is not None and \
632 self._num_yielded > self._IterableDataset_len_called:
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
671 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
672 if self._pin_memory:
--> 673 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
674 return data
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:58, in pin_memory(data, device)
56 elif isinstance(data, collections.abc.Mapping):
57 try:
---> 58 return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
59 except TypeError:
60 # The mapping type may not support `__init__(iterable)`.
61 return {k: pin_memory(sample, device) for k, sample in data.items()}
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:58, in (.0)
56 elif isinstance(data, collections.abc.Mapping):
57 try:
---> 58 return type(data)({k: pin_memory(sample, device) for k, sample in data.items()}) # type: ignore[call-arg]
59 except TypeError:
60 # The mapping type may not support `__init__(iterable)`.
61 return {k: pin_memory(sample, device) for k, sample in data.items()}
File /usr/local/python/lib/python3.8/site-packages/torch/utils/data/_utils/pin_memory.py:53, in pin_memory(data, device)
51 def pin_memory(data, device=None):
52 if isinstance(data, torch.Tensor):
---> 53 return data.pin_memory(device)
54 elif isinstance(data, string_classes):
55 return data
RuntimeError: cannot pin 'torch.cuda.FloatTensor' only dense CPU tensors can be pinned