I am trying to build my graph CNN model with PyTorch Geometric, but before really working on the model itself, I have found some problems on constructing my dataset with torch_geometric.data.Dataset.
The official doc “Creating Your Own Datasets” gives an example:
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
### and more...
and says that users need to implement the following methods:
torch_geometric.data.InMemoryDataset.raw_file_names():
A list of files in the raw_dir which needs to be found in order to skip the download.
torch_geometric.data.InMemoryDataset.processed_file_names():
A list of files in the processed_dir which needs to be found in order to skip the processing.
In my situation, I have my local dataset so I can prepare all the data paths. So, I tried to make my own dataset as:
class MyOwnDataset(Dataset):
# I pass my data paths as parameter
def __init__(self, root, input_data_paths, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.input_data_paths = input_data_paths
def _download(self):
# Don't need to download
pass
@property
def raw_file_names(self):
# Try to pass my paths here (but got error!)
return self.input_data_paths
@property
def processed_file_names(self):
# Try to replicate names defined in process()
return [ 'data_{}.pt'.format(i) for i in range(len(self.raw_paths)) ]
def __len__(self):
return len(self.processed_file_names)
def process(self):
i = 0
for raw_path in self.raw_paths:
# Test dummy data
data = torch.tensor([1,1,1])
torch.save(data, ops.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
# Test dummy paths
input_paths = ["./p1", "./p2"]
# Create dataset
train_dataset = MyOwnDataset("/tmp/Data/train/", input_paths)
Unfortunately, I got an AttributeError:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-17-d69608eea4a9> in <module>
4 # Create dataset
5 train_params = {'batch_size':1, 'shuffle': True, 'num_workers': 1}
----> 6 train_dataset = MyOwnDataset("/tmp/Data/train/", input_paths)
<ipython-input-16-a9fe3ae7ad91> in __init__(self, root, input_data_paths, transform, pre_transform)
2 def __init__(self, root, input_data_paths,
3 transform=None, pre_transform=None):
----> 4 super(MyOwnDataset, self).__init__(root, transform, pre_transform)
5 self.input_data_paths = input_data_paths
6
/usr/local/lib/python3.6/dist-packages/torch_geometric/data/dataset.py in __init__(self, root, transform, pre_transform, pre_filter)
81
82 self._download()
---> 83 self._process()
84
85 @property
/usr/local/lib/python3.6/dist-packages/torch_geometric/data/dataset.py in _process(self)
119
120 def _process(self):
--> 121 if files_exist(self.processed_paths): # pragma: no cover
122 return
123
/usr/local/lib/python3.6/dist-packages/torch_geometric/data/dataset.py in processed_paths(self)
108 r"""The filepaths to find in the :obj:`self.processed_dir`
109 folder in order to skip the processing."""
--> 110 files = to_list(self.processed_file_names)
111 return [osp.join(self.processed_dir, f) for f in files]
112
<ipython-input-16-a9fe3ae7ad91> in processed_file_names(self)
17 def processed_file_names(self):
18 # Try to replicate names defined in process()
---> 19 return [ 'data_{}.pt'.format(i) for i in range(len(self.raw_paths)) ]
20
21 def __len__(self):
/usr/local/lib/python3.6/dist-packages/torch_geometric/data/dataset.py in raw_paths(self)
101 def raw_paths(self):
102 r"""The filepaths to find in order to skip the download."""
--> 103 files = to_list(self.raw_file_names)
104 return [osp.join(self.raw_dir, f) for f in files]
105
<ipython-input-16-a9fe3ae7ad91> in raw_file_names(self)
12 @property
13 def raw_file_names(self):
---> 14 return self.input_data_paths
15
16 @property
AttributeError: 'MyOwnDataset' object has no attribute 'input_data_paths'
Unlike the official example which defines the raw_file_names explicitly (e.g. [‘some_file_1’, ‘some_file_2’, …]), I want to reuse this Dataset class for several of my datasets so the paths are preferably variables.
Therefore, here comes my question: is there any way to define raw_file_names dynamically?
Besides, I have one more question on defining processed_file_names. I tried to define the names twice (in processed_file_names() and process()) but in the same way so the names match with each other, and I do not believe this would be the best practice. Is there any way to define them dynamically, e.g. depends on input paths?
Sorry for making such a lengthy question and thank you in advance for the help.