I’m working on a project which is trained using two different datasets. I’d like to use two DataLoaders to do this; however, to loop through a DataLoader I am forced to use its iterable object. This is restrictive because I’d like to sample (without replacement) from each dataset. There are other issues too: the sizes of the datasets are different, and the batch sizes for the samples are also different.
I saw a similar question in which a custom dataset was created to remedy this. However, with that approach, I would not be able to run entire epochs over the larger dataset, and I would not be able to use different batch sizes. Any suggestions?
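For concreteness, a minimal sketch of the kind of setup I mean (the dataset sizes, feature shapes, and batch sizes are just placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Two datasets of different sizes (placeholder tensors)
ds_a = TensorDataset(torch.randn(100, 8))
ds_b = TensorDataset(torch.randn(60, 8))

# shuffle=True samples without replacement within each pass,
# and the two loaders use different batch sizes
loader_a = DataLoader(ds_a, batch_size=16, shuffle=True)
loader_b = DataLoader(ds_b, batch_size=4, shuffle=True)

print(len(loader_a))  # 7 batches (ceil(100 / 16))
print(len(loader_b))  # 15 batches (ceil(60 / 4))
```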
It is possible to partially iterate over a torch.utils.data.DataLoader. Below is a small snippet where the training data is iterated over until some condition is met, and then the validation data is iterated over until another condition is met. During each iteration of the outer loop, trainloader and valloader will continue from where they left off in the previous loop.
If you were using multiple datasets, you could change the condition to account for the proportion you want, fill some buffers with your data, and then act on the buffered data.
for i in range(N):
    for data_idx, data in enumerate(trainloader):
        # DO SOMETHING
        if CONDITION:
            break
    for data_idx, data in enumerate(valloader):
        # DO SOMETHING
        if CONDITION:
            break
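A concrete, runnable version of that sketch, with buffers filled in proportion as described above (the 2:1 ratio, dataset contents, and batch sizes are all hypothetical placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the two datasets
loader_a = DataLoader(TensorDataset(torch.arange(20)), batch_size=2)
loader_b = DataLoader(TensorDataset(torch.arange(10)), batch_size=1)

N = 3
for i in range(N):
    buffer_a, buffer_b = [], []
    # Condition: take 2 batches from A for every 1 batch from B
    for data_idx, (data,) in enumerate(loader_a):
        buffer_a.append(data)
        if data_idx == 1:  # 2 batches collected
            break
    for data_idx, (data,) in enumerate(loader_b):
        buffer_b.append(data)
        if data_idx == 0:  # 1 batch collected
            break
    # Act on the buffered data, e.g. combine it into one tensor
    mixed = torch.cat(buffer_a + buffer_b)
```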
I don’t think they will continue, since enumerate() creates a fresh iterator each time the inner loop is entered, and iteration restarts from the beginning. If you want to continue from the last break, you would have to use something like a generator. Did you manage to get the desired behavior?
If so, could you share some code?
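To illustrate what I mean by a generator: a small sketch that wraps any iterable (a DataLoader included) so its position survives across loops. Shown here with a plain list standing in for a DataLoader, since the pattern doesn’t depend on torch:

```python
def batch_stream(loader):
    # Wrap an iterable (e.g. a DataLoader) in a generator so that
    # next() always continues where the previous call left off,
    # restarting from the top only once the loader is exhausted.
    while True:
        for batch in loader:
            yield batch

# Hypothetical usage, with [0, 1, 2] standing in for a DataLoader
stream = batch_stream([0, 1, 2])
first = [next(stream) for _ in range(2)]   # [0, 1]
second = [next(stream) for _ in range(3)]  # [2, 0, 1] -- continues, then wraps
```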
Here is a small example showing that the last samples are dropped:
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

train_dataset = datasets.FakeData(size=10,
                                  image_size=(3, 24, 24),
                                  transform=transforms.ToTensor())
test_dataset = datasets.FakeData(size=10,
                                 image_size=(3, 24, 24),
                                 transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset)
test_loader = DataLoader(test_dataset)

for epoch in range(3):
    for batch_idx, (data, target) in enumerate(train_loader):
        print('epoch {} train batch {}'.format(epoch, batch_idx))
        if batch_idx == 5:
            break
    for batch_idx, (data, target) in enumerate(test_loader):
        print('epoch {} test batch {}'.format(epoch, batch_idx))
        if batch_idx == 5:
            break
Comment out the condition and the break statement and you will see that each loader returns all 10 samples.
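The restart behaviour can also be checked programmatically. A sketch using a dataset of plain indices instead of FakeData, so we can record exactly which samples were seen:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
loader = DataLoader(dataset)  # batch_size=1, no shuffling

seen = set()
for epoch in range(3):
    for batch_idx, (data,) in enumerate(loader):
        seen.add(data.item())
        if batch_idx == 5:
            break

# Each epoch restarts from the beginning, so samples 6-9 never appear
print(sorted(seen))  # [0, 1, 2, 3, 4, 5]
```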