ok so i have this code for a combination dataset and dataloader:
# combined dataset class
class CombinationDataset(torch.utils.data.DataLoader):
def __init__(self, datasets):
self.datasets = datasets
def __len__(self):
return(sum([dataset.__len__() for dataset in self.datasets]))
def __getitem__(self, indicies):
dataset_idx = indicies[0]
data_idx = indicies[1]
print(indicies)
return self.datasets[dataset_idx].__getitem__(data_idx)
# class that will take in multiple samplers and output batches from a single dataset at a time
class ComboBatchSampler():
def __init__(self, samplers, batch_size, drop_last):
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.samplers = samplers
self.iterators = [iter(sampler) for sampler in self.samplers]
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self) -> Iterator[List[int]]:
# define how many batches we will grab
self.min_batches = min([len(sampler) for sampler in self.samplers])
self.n_batches = self.min_batches * len(self.samplers)
# define which indicies to use for each batch
self.dataset_idxs = []
random.seed(42)
for j in range((self.n_batches//len(self.samplers) + 1)):
loader_inds = list(range(len(self.samplers)))
random.shuffle(loader_inds)
self.dataset_idxs.extend(loader_inds)
self.dataset_idxs = self.dataset_idxs[:self.n_batches]
# return the batch indicies
batch = []
for dataset_idx in self.dataset_idxs:
for idx in self.samplers[dataset_idx]:
batch.append((dataset_idx, idx))
if len(batch) == self.batch_size:
yield (batch)
batch = []
break
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self) -> int:
if self.drop_last:
return (sum([len(sampler) for sampler in self.samplers])) // self.batch_size
else:
return (sum([len(sampler) for sampler in self.samplers]) + self.batch_size - 1) // self.batch_size
i construct my data like this:
train_data_combined = CombinationDataset([train_data_0, train_data_1])
sampler = ComboBatchSampler([torch.utils.data.sampler.RandomSampler(dataset) for dataset in [train_data_0, train_data_1]],
batch_size=4, drop_last=True)
return DataLoader(train_data_combined, sampler=sampler)
however i get this output and error when trying for batch in train_loader:
[(1, 2383), (1, 2094), (1, 1501), (1, 3252)]
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/tmp/ipykernel_1022/1854105238.py in <module>
1 #Try a couple of epochs to make sure it works (or inturrupt after a small amount of time if it does)
2 for epoch in range(1):
----> 3 for batch in train_loader:
4 metrics = testTrial.train_batch(batch, epoch, batch_idx)
5 print(metrics)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
519 if self._sampler_iter is None:
520 self._reset()
--> 521 data = self._next_data()
522 self._num_yielded += 1
523 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
559 def _next_data(self):
560 index = self._next_index() # may raise StopIteration
--> 561 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
562 if self._pin_memory:
563 data = _utils.pin_memory.pin_memory(data)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/run/determined/workdir/combined.py in __getitem__(self, indicies)
14 data_idx = indicies[1]
15 print(indicies)
---> 16 return self.datasets[dataset_idx].__getitem__(data_idx)
17
18 # class that will take in multiple samplers and output batches from a single dataset at a time
TypeError: list indices must be integers or slices, not tuple
it looks like its trying to call getitem() with a list of tuples as input and not calling it one time for each entry. how should i solve this?