I currently have the following Dataset class:
import numpy as np
import torch


class CVRPGraphDataset(torch.utils.data.Dataset):
    def __init__(self, data_file, sparse_factor=-1):
        self.data_file = data_file
        self.sparse_factor = sparse_factor
        with open(data_file) as f:
            self.file_lines = f.read().splitlines()
        print(f'Loaded "{data_file}" with {len(self.file_lines)} lines')

    def __len__(self):
        return len(self.file_lines)

    def get_example(self, idx):
        line = self.file_lines[idx]
        line = line.strip()
        capacity = float(line.split()[0])
        # Extract points
        points = line.split(" points ")[1].split(" demands ")[0]
        points = points.split()
        # Extract demands
        demands = line.split(" demands ")[1].split(" output ")[0]
        demands = demands.split()
        if len(demands) != len(points) / 2:
            raise ValueError(f"Number of demands {len(demands)} differs from number of points {len(points) // 2}")
        points = np.array(
            [[float(points[i]), float(points[i + 1]), float(demands[i // 2])] for i in range(0, len(points), 2)]
        )
        # Extract route
        full_route = line.split(" output ")[1]
        full_route = full_route.split()
        full_route = np.array([int(t) for t in full_route[:-1]])
        if min(full_route) == 1:
            full_route -= 1
        return points, full_route

    def __getitem__(self, idx):
        points, route = self.get_example(idx)
        # Build the adjacency matrix of the route (one edge per consecutive pair of nodes)
        adj_matrix = np.zeros((points.shape[0], points.shape[0]))
        for i in range(route.shape[0] - 1):
            adj_matrix[route[i], route[i + 1]] = 1
        # return points, adj_matrix, route
        # Pad the route to a fixed maximum length so samples can be collated into batches
        max_route_size = 2 * points.shape[0] + 1
        pad_size = max_route_size - len(route)
        route = np.pad(route, pad_width=(0, pad_size), mode="constant", constant_values=-1)
        route_tensor = torch.from_numpy(route).long()
        result = (
            torch.LongTensor(np.array([idx], dtype=np.int64)),
            torch.from_numpy(points).float(),
            torch.from_numpy(adj_matrix).float(),
            route_tensor,
        )
        return result
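For completeness, this is roughly how I create and index the dataset outside of the Lightning pipeline (the path below is just a placeholder):

dataset = CVRPGraphDataset("data/cvrp_test.txt")  # placeholder path
print(len(dataset))  # matches the line count printed in __init__

# Direct indexing with a valid idx
idx_t, points, adj_matrix, route = dataset[len(dataset) - 1]
print(points.shape, adj_matrix.shape, route.shape)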
After an update in a seemingly unrelated place (the test_step of my LightningModule), I got an IndexError at this line:
def get_example(self, idx):
    line = self.file_lines[idx]
because idx is greater than the size of my Dataset. My question is: How can I debug this?
The print in __init__ shows that the right number of samples was loaded from each file:
print(f'Loaded "{data_file}" with {len(self.file_lines)} lines')
I also printed the contents of my files and they look fine. My DataLoaders use the default sampler, with num_workers=12 and batch_size=10 (I load 60/20/20 train/test/val samples).
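To narrow this down, I was thinking of wrapping the dataset in a small debugging shim that reports the offending index, the dataset length, and the worker that requested it before the crash. A rough sketch (IndexCheckingDataset is not part of my code):

import torch


class IndexCheckingDataset(torch.utils.data.Dataset):
    """Debugging wrapper: reports out-of-range indices before they crash a worker."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        if not 0 <= idx < len(self.dataset):
            info = torch.utils.data.get_worker_info()
            worker = info.id if info is not None else "main process"
            print(f"Out-of-range idx={idx}, dataset len={len(self.dataset)}, worker={worker}")
        return self.dataset[idx]

I would wrap each split's dataset with this before building its DataLoader, and run once with num_workers=0 to get a cleaner traceback. Is there a more systematic way to find out where the bad index comes from?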