Test Loader returns the same batch in each iteration

My code is structured like this:

for epoch in range(args.max_epochs):
    model.train(True)
    ...
    for batch_idx, input in enumerate(train_loader):
        ...
    model.eval()
    testframe = next(iter(test_loader))
    sample_frame(model, testframe)

Somehow, when generating samples, the batch that the generated frames are conditioned on is the same every time.

I guess it’s a matter of creating the iterator beforehand, like:

iterator = iter(test_loader)
next(iterator)

The way you have it now, iter(test_loader) is called on every pass, so you are re-initializing the iterator the whole time and always get the first batch.
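A minimal pure-Python illustration of why this happens (no PyTorch needed): iter() returns a fresh iterator on every call, so next(iter(...)) always yields the first element, while a single persistent iterator advances.

```python
data = [10, 20, 30]

# Calling iter() anew each time restarts iteration from the beginning:
first_calls = [next(iter(data)) for _ in range(3)]
assert first_calls == [10, 10, 10]

# A single persistent iterator advances through the sequence instead:
it = iter(data)
persistent_calls = [next(it) for _ in range(3)]
assert persistent_calls == [10, 20, 30]
```

A DataLoader behaves the same way: each iter(test_loader) starts iteration over again.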


Oh, I see. I’ll check whether defining the iterator before the training loop solves the problem, but I’m pretty sure it will. Thx!

Actually, I still get the same test batch, even when I define the iterator before the training loop :confused:

Hi,
can you check if your dataset is properly defined?
If you call __getitem__(idx) manually, are you getting different samples?
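As a sanity check, this kind of manual indexing can be tried directly on the dataset; ToyFrameDataset below is a hypothetical stand-in for the real test dataset, just to show the pattern.

```python
import torch
from torch.utils.data import Dataset

class ToyFrameDataset(Dataset):
    # Hypothetical stand-in: each "frame" is a tensor filled with its index,
    # so different indices are easy to tell apart.
    def __init__(self, n=8):
        self.frames = [torch.full((3, 4, 4), float(i)) for i in range(n)]

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        return self.frames[idx]

ds = ToyFrameDataset()
# Indexing manually should return different samples for different indices:
assert not torch.equal(ds[0], ds[1])
```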

Why don’t you call the test_loader the same way as the train_loader?

Specifically:

model.eval()
for batch_idx, input in enumerate(test_loader):
    sample_frame(model, input)

or something similar…

Because I only need a single test batch for sampling from my model. But in every training-loop iteration (actually every 10th), I would like to get a different test batch.
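One way to get a different batch on each call without managing an iterator by hand is to build the test loader with shuffle=True: every fresh iter(test_loader) then draws a new random permutation, so its first batch differs. A sketch, with a TensorDataset as a hypothetical stand-in for the real test set:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the real test set
test_dataset = TensorDataset(torch.arange(100).float().unsqueeze(1))

# shuffle=True re-permutes the data for every new iterator, so
# next(iter(test_loader)) yields a (almost always) different batch per call
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True)

batch_a = next(iter(test_loader))[0]
batch_b = next(iter(test_loader))[0]
```

Note that with shuffle=False (the default), next(iter(test_loader)) deterministically returns the same first batch every time.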

When I call __getitem__(idx) on the test dataset I get different samples, as expected. So maybe it would work to simply pull the samples from the dataset instance without using a DataLoader?
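That approach can work; a sketch of batching directly from the dataset without a DataLoader, using a TensorDataset as a hypothetical stand-in for the real test dataset (names are illustrative):

```python
import torch
from torch.utils.data import TensorDataset

# Hypothetical stand-in for the real test dataset
test_dataset = TensorDataset(torch.arange(20).float().unsqueeze(1))

batch_size = 4
# Draw a fresh set of random indices each time a batch is needed,
# then collate the individual samples into one batch tensor.
indices = torch.randperm(len(test_dataset))[:batch_size]
testframe = torch.stack([test_dataset[i][0] for i in indices.tolist()])
```

This skips the DataLoader's collation and multiprocessing entirely, which is fine for an occasional sampling batch.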

Hi,

My opinion is similar to @JuanFMontesinos’s. I think you should define the iterator outside both for-loops; if you create it inside the epoch loop, you will get the same samples again and again.

What you do should look like this:

test_iterator = iter(test_loader)
for epoch in range(args.max_epochs):
    model.train(True)
    ...
    for batch_idx, input in enumerate(train_loader):
        ...
    model.eval()
    try:
        test_samples = next(test_iterator)
    except StopIteration:
        # re-create the iterator once the test loader is exhausted
        test_iterator = iter(test_loader)
        test_samples = next(test_iterator)
    ...

I already tried this after @JuanFMontesinos’s first reply, but it didn’t work. Now I found that it suddenly works when I set num_workers to a value > 0. I have no idea why, though.

This is the main reason, I bet.