PyTorch Geometric: Accessing batch in a custom "In Memory" dataset fails

When trying to get batch in dataloader I get a key error from torch_geometric\data\storage.py. Please see below:

import torch
from torch_geometric.data import Data, InMemoryDataset
import matplotlib.pyplot as plt
import networkx as nx

Define the edge index and features for the first graph

x1 = torch.tensor([[1, 2], [3, 4]], dtype=torch.float)
edge_index1 = torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=torch.long)

Define the edge index and features for the second graph

x2 = torch.tensor([[2, 3], [1, 4]], dtype=torch.float)
edge_index2 = torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=torch.long)

Create a list of PyTorch Geometric Data objects to represent the graphs

data_list = [Data(x=x1, edge_index=edge_index1),
Data(x=x2, edge_index=edge_index2)]

Define the custom dataset class

class CustomDataset(InMemoryDataset):
def init(self, root, transform=None, pre_transform=None):
super(CustomDataset, self).init(root, transform, pre_transform)
self.data, self.slices = self.collate(data_list)

  @property
  def raw_file_names(self):
      return []

  @property
  def processed_file_names(self):
      return []

  def download(self):
      pass

  def process(self):
      pass

  def __len__(self):
      return len(self.data)

  def get(self, idx):
      # Return the Data object at the specified index
      return self.data[idx]

Initialize the custom dataset

dataset = CustomDataset(root=‘./data’)

Apply transformations to the data if desired

if dataset.transform is not None:
dataset.transform = T.Compose([T.RandomRotate(30, resample=False),
T.RandomTranslate(0.1)])

Put the dataset in a data loader

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)

Loop over the dataloader

for batch in dataloader:
# Access the data for each graph in the batch
# x, edge_index, y = batch.x, batch.edge_index, batch.y
x, edge_index = batch.x, batch.edge_index

  # Do something with the data
  print(x, edge_index, y)

######## Error #######
KeyError Traceback (most recent call last)
Cell In[4], line 2
1 # Loop over the dataloader
----> 2 for batch in dataloader:
3 # Access the data for each graph in the batch
4 # x, edge_index, y = batch.x, batch.edge_index, batch.y
5 x, edge_index = batch.x, batch.edge_index
7 # Do something with the data

File ~\Anaconda3\envs\pyg\lib\site-packages\torch\utils\data\dataloader.py:681, in _BaseDataLoaderIter.next(self)
678 if self._sampler_iter is None:
679 # TODO(Bug in dataloader iterator found by mypy · Issue #76750 · pytorch/pytorch · GitHub)
680 self._reset() # type: ignore[call-arg]
→ 681 data = self._next_data()
682 self._num_yielded += 1
683 if self._dataset_kind == _DatasetKind.Iterable and
684 self._IterableDataset_len_called is not None and
685 self._num_yielded > self._IterableDataset_len_called:

File ~\Anaconda3\envs\pyg\lib\site-packages\torch\utils\data\dataloader.py:721, in _SingleProcessDataLoaderIter._next_data(self)
719 def _next_data(self):
720 index = self._next_index() # may raise StopIteration
→ 721 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
722 if self._pin_memory:
723 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~\Anaconda3\envs\pyg\lib\site-packages\torch\utils\data_utils\fetch.py:49, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
47 def fetch(self, possibly_batched_index):
48 if self.auto_collation:
—> 49 data = [self.dataset[idx] for idx in possibly_batched_index]
50 else:
51 data = self.dataset[possibly_batched_index]

File ~\Anaconda3\envs\pyg\lib\site-packages\torch\utils\data_utils\fetch.py:49, in (.0)
47 def fetch(self, possibly_batched_index):
48 if self.auto_collation:
—> 49 data = [self.dataset[idx] for idx in possibly_batched_index]
50 else:
51 data = self.dataset[possibly_batched_index]

File ~\Anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\dataset.py:197, in Dataset.getitem(self, idx)
187 r"““In case :obj:idx is of type integer, will return the data object
188 at index :obj:idx (and transforms it in case :obj:transform is
189 present).
190 In case :obj:idx is a slicing object, e.g., :obj:[2:5], a list, a
191 tuple, or a :obj:torch.Tensor or :obj:np.ndarray of type long or
192 bool, will return a subset of the dataset at the specified indices.””"
193 if (isinstance(idx, (int, np.integer))
194 or (isinstance(idx, Tensor) and idx.dim() == 0)
195 or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
→ 197 data = self.get(self.indices()[idx])
198 data = data if self.transform is None else self.transform(data)
199 return data

Cell In[3], line 43, in CustomDataset.get(self, idx)
41 def get(self, idx):
42 # Return the Data object at the specified index
—> 43 return self.data[idx]

File ~\Anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\data.py:444, in Data.getitem(self, key)
443 def getitem(self, key: str) → Any:
→ 444 return self._store[key]

File ~\Anaconda3\envs\pyg\lib\site-packages\torch_geometric\data\storage.py:81, in BaseStorage.getitem(self, key)
80 def getitem(self, key: str) → Any:
—> 81 return self._mapping[key]

KeyError: 0