Thanks a lot for your suggestion. I managed to solve the issue by building a list of tuples of (data_file_path, desired_index, class_label) and then, in __getitem__, returning torch.load(data_file)[desired_index : desired_index + window_size] together with the label, which I use to generate batches.
For anyone interested, this is the class:
import os
from random import shuffle

import torch
from torch.utils import data


class TimeSeriesDataSet(data.Dataset):
    def __init__(self, tensor_dir, transform, window_size, stride):
        self.tensor_directory = tensor_dir
        self.transform = transform
        self.files = os.listdir(tensor_dir)
        self.window_size = window_size
        self.stride = stride
        self.data_tuples = []
        for f in self.files:
            file = os.path.join(tensor_dir, f)
            series, label = torch.load(file)  # each file stores a (tensor, label) tuple
            # Pad with zeros when the tensor's length is not a multiple of the window size
            if series.size(0) % self.window_size != 0:
                padding = self.window_size - (series.size(0) % self.window_size)
                zeros = torch.zeros(padding, series.size(1), dtype=series.dtype)
                series = torch.cat((series, zeros), dim=0)
            # Index every window as (file, start_index, label); the +1 keeps the
            # final window, which the padding above makes valid
            for j in range(0, series.size(0) - self.window_size + 1, self.stride):
                self.data_tuples.append((file, j, label))
        shuffle(self.data_tuples)

    def __len__(self):
        return len(self.data_tuples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        file, start, label = self.data_tuples[idx]
        sample, _ = torch.load(file)
        sample = sample[start : start + self.window_size]
        if self.transform:
            sample = self.transform(sample)
        return {'sample': sample, 'label': label}
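For completeness, here is a minimal usage sketch. The directory path, window size, stride, and batch size are placeholder values, and it assumes each file in the directory stores a (tensor, label) tuple as above:

from torch.utils.data import DataLoader

# Placeholder arguments; adjust to your own data
dataset = TimeSeriesDataSet('path/to/tensors', transform=None, window_size=50, stride=10)
loader = DataLoader(dataset, batch_size=32)

for batch in loader:
    samples, labels = batch['sample'], batch['label']
    # samples has shape (batch_size, window_size, num_features)

Note that since the dataset already shuffles its (file, index, label) list in __init__, the windows come out in random order even without shuffle=True on the DataLoader; passing shuffle=True additionally reshuffles every epoch.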