I am trying to run Node2Vec from the torch_geometric.nn library. For reference, I am following this example.
When I run the `train()` function, I keep getting `TypeError: tuple indices must be integers or slices, not tuple`.
I am using torch 1.6.0 with CUDA 10.1, along with the latest versions of `torch-scatter`, `torch-sparse`, `torch-cluster`, `torch-spline-conv`, and `torch-geometric`.
Please excuse any formatting mistakes; this is my first post.
Thanks for any help.
Here is the full traceback:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-11-cf6dc0596f64> in <module>()
1 for epoch in range(1, 101):
----> 2 loss = train()
3 print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
4 frames
<ipython-input-10-dfee452d1f45> in train()
3 total_loss = 0
4 # print(loader)
----> 5 for pos_rw, neg_rw in loader:
6 optimizer.zero_grad()
7 loss = model.loss(pos_rw.to(device), neg_rw.to(device))
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
854 else:
855 del self._task_info[idx]
--> 856 return self._process_data(data)
857
858 def _try_put_index(self):
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
879 self._try_put_index()
880 if isinstance(data, ExceptionWrapper):
--> 881 data.reraise()
882 return data
883
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
393 # (https://bugs.python.org/issue2651), so we work around it.
394 msg = KeyErrorMessage(msg)
--> 395 raise self.exc_type(msg)
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
return self.collate_fn(data)
File "/usr/local/lib/python3.6/dist-packages/torch_geometric/nn/models/node2vec.py", line 116, in sample
return self.pos_sample(batch), self.neg_sample(batch)
File "/usr/local/lib/python3.6/dist-packages/torch_geometric/nn/models/node2vec.py", line 97, in pos_sample
walks.append(rw[:, j:j + self.context_size])
TypeError: tuple indices must be integers or slices, not tuple