Hi,
I want to concatenate testing samples and training samples (CIFAR-10), and then using this dataset in the test. I used __getitem
:
class MyTestDataset():
def __init__(self, transform_test=None, transform_train=None):
Train = datasets.CIFAR10(root='~/data', train=True,download=True,transform=transform_train)
Test = datasets.CIFAR10(root='~/data', train=False,download=False,transform=transform_test)
self.cifar_len = 10
rand_idx = torch.randperm(len(Train.data))[:self.cifar_len]
self.Train_data = Train.data[rand_idx]
rand_idxt = torch.randperm(len(Test.data))[:self.cifar_len]
self.Test_data = Test.data[rand_idxt]
def __getitem(self, index):
x1, y1 = self.Train_data[index]
x2, y2 = self.Test_data[index]
x = torch.stack((x1, x2))
y = torch.stack((torch.tensor(y1), torch.tensor(y2)))
return x, y
def __len__(self):
return self.cifar_len + self.cifar_len
dataset = MyTestDataset(transform_test=transform_test, transform_train=transform_train)
loader = torch.utils.data.DataLoader(dataset,batch_size=64,shuffle=False, num_workers=8)
and in the test, I used:
for inputs, targets in loader:
but this gives error:
File "...", line 339, in test
for inputs, targets in loader:
File "...", line 637, in __next__
return self._process_next_batch(batch)
File "...", line 658, in _process_next_batch
raise batch.exc_type(batch.exc_msg)
TypeError: Traceback (most recent call last):
File "...", line 138, in _worker_loop
samples = collate_fn([dataset[i] for i in batch_indices])
File "...", line 138, in <listcomp>
samples = collate_fn([dataset[i] for i in batch_indices])
TypeError: 'MyTestDataset' object does not support indexing