shivangi
1
What is the best way to combine train and test data in torchvision and how can we do it?
Can someone explain how we can do it using ConcatDataset?
I tried the following:
train = datasets.MNIST(root=dirpath, train=True, download=True, transform=trans)
test = datasets.MNIST(root=dirpath, train=False, download=True, transform=trans)
data_list = list()
data_list.append(train)
data_list.append(test)
total = ConcatDataset(data_list)
But it did not work.
Did you get an error?
Your code should work.
Alternatively try this:
train_dataset = datasets.MNIST(root=dirpath, train=True, transform=trans)
test_dataset = datasets.MNIST(root=dirpath, train=False, transform=trans)
dataset = ConcatDataset([train_dataset, test_dataset])
Both snippets should be equivalent, so let me know if mine works.
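To check whether ConcatDataset behaves as expected without downloading anything, here is a minimal sketch. The two TensorDataset objects are hypothetical stand-ins for the MNIST splits (random tensors with made-up sizes), so only the ConcatDataset and DataLoader calls mirror the snippets above.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# Hypothetical stand-ins for the MNIST train/test splits (random data, no download).
train_dataset = TensorDataset(torch.randn(60, 1, 28, 28), torch.randint(0, 10, (60,)))
test_dataset = TensorDataset(torch.randn(10, 1, 28, 28), torch.randint(0, 10, (10,)))

dataset = ConcatDataset([train_dataset, test_dataset])

# The combined length is the sum of the parts.
print(len(dataset))  # 70

# Indices past the end of the first dataset map into the second one.
image, label = dataset[60]
same_sample = torch.equal(image, test_dataset[0][0])
print(same_sample)  # True

# The combined dataset can be passed to a DataLoader like any other dataset.
loader = DataLoader(dataset, batch_size=16, shuffle=True)
images, labels = next(iter(loader))
print(tuple(images.shape))  # (16, 1, 28, 28)
```

If len(dataset) matches the sum of both splits and a batch comes out of the loader, the concatenation itself is working and the problem is likely elsewhere.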
Can someone explain how we can do it using ConcatDataset?
DATASET_DIR = '...'

def load_data(data_dir=DATASET_DIR):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]),
    train_data = dataset(
        root_dir=data_dir,
        fold_nums=[1,2,3,4,5],
        transforms=transform,
        albumentations_package=False
    ),
    test_data = dataset(
        root_dir=data_dir,
        fold_nums=[6],
        transforms=transform,
        albumentations_package=False
    ),
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        num_workers=num_workers),
    test_loader = torch.utils.data.DataLoader(
        val_data,
        batch_size=batch_size,
        num_workers=num_workers)
It does not work here with this code. Why?
dataset = ConcatDataset([train_data, test_data])
ptrblck
4
What kind of error are you seeing?
> ---------------------------------------------------------------------------
> NameError Traceback (most recent call last)
> /tmp/ipykernel_15260/1455441349.py in <module>
> 3 criterion = nn.CrossEntropyLoss()
> 4
> ----> 5 dataset = ConcatDataset([train_data, test_data])
> 6
> 7 num_epochs=10
>
> NameError: name 'train_data' is not defined
ptrblck
6
It seems you did not execute the previous code snippet which creates train_data.
It looks like it, but I’ve already done that. I have already restarted the kernel.
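One possible explanation, assuming the assignments to train_data and test_data sit inside load_data as the posted snippet suggests: names assigned inside a function are local to it, so they are not visible in the cell that calls ConcatDataset unless the function returns them and the caller stores the result. The sketch below uses plain lists as hypothetical stand-ins for the real datasets. It also shows that a trailing comma after an assignment turns the value into a one-element tuple, which may be worth checking in the posted code, along with val_data, which is never defined there.

```python
def load_data_no_return():
    # These names are local to the function and disappear when it returns.
    train_data = ["sample_0", "sample_1"]  # hypothetical stand-in for a dataset
    test_data = ["sample_2"]

load_data_no_return()
try:
    train_data  # not defined at this level, even though the function ran
    name_error_raised = False
except NameError:
    name_error_raised = True
print(name_error_raised)  # True

def load_data():
    train_data = ["sample_0", "sample_1"]
    test_data = ["sample_2"]
    return train_data, test_data  # hand both objects back to the caller

train_data, test_data = load_data()
# With real datasets this would be ConcatDataset([train_data, test_data]).
combined = train_data + test_data
print(len(combined))  # 3

# A trailing comma after an assignment wraps the value in a one-element tuple.
wrapped = ["sample_0"],
print(type(wrapped).__name__)  # tuple
```

Restarting the kernel would not help in this situation, because the names are never created at the notebook level in the first place.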