PyTorch DataLoader stuck on simple MNIST example

Hello, I have a simple MNIST example set up. A few days ago it was working perfectly, but when I ran it today the DataLoader got stuck every time. It won't load a single batch; it just runs forever. Here is part of the code:

Prepare and load the data:

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

train_samples = datasets.ImageFolder('data/train', transforms.ToTensor())
val_samples = datasets.ImageFolder('data/val', transforms.ToTensor())

train_set = DataLoader(train_samples, batch_size=170, shuffle=True, num_workers=4)
val_set = DataLoader(val_samples, batch_size=170, shuffle=False, num_workers=4)
```
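With `batch_size=170` and `drop_last` left at its default of `False`, one pass over the loader yields ceil(N / 170) batches, and the final batch is smaller whenever N is not a multiple of 170. That number is useful for judging whether the loop is progressing slowly or not at all. A minimal stdlib sketch of the arithmetic, assuming the standard 60,000-image MNIST training split (the actual size of `data/train` is not stated in the post); the helper names are hypothetical:

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader-style batcher yields per epoch."""
    if drop_last:
        return num_samples // batch_size  # the incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # the final batch may be smaller


def last_batch_size(num_samples: int, batch_size: int) -> int:
    """Size of the final batch when drop_last=False."""
    remainder = num_samples % batch_size
    return remainder if remainder else batch_size


# Assumption: 60,000 training images, as in the standard MNIST split.
n = 60_000
print(batches_per_epoch(n, 170))                  # 353
print(last_batch_size(n, 170))                    # 160
print(batches_per_epoch(n, 170, drop_last=True))  # 352
```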

Training loop:

```python
import torch

def train(model, optimizer, criterion):
    model.train()  # training mode
    running_loss = 0
    running_corrects = 0
    for x, y in train_set:
        optimizer.zero_grad()  # reset the gradients to zero
        output = model(x)  # forward pass
        _, preds = torch.max(output, 1)
        loss = criterion(output, y)  # calculate the loss value
        loss.backward()  # compute the gradients
        optimizer.step()  # update the network parameters
        # statistics
        running_loss += loss.item() * x.size(0)
        running_corrects += torch.sum(preds == y).item()
    epoch_loss = running_loss / len(train_samples)  # mean epoch loss
    epoch_acc = running_corrects / len(train_samples)  # mean epoch accuracy
    return epoch_loss, epoch_acc
```
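Multiplying each batch's loss by `x.size(0)` matters because losses such as `nn.CrossEntropyLoss` return the mean over the batch by default, and the last batch can be smaller than 170. Summing loss × batch size and dividing by `len(train_samples)` gives the true per-sample mean, whereas averaging the batch means would over-weight a small final batch. A stdlib sketch with hypothetical batch losses, purely for illustration:

```python
def epoch_mean(batch_means, batch_sizes):
    """Per-sample mean over an epoch, weighting each batch mean by its size."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


def naive_mean(batch_means):
    """Unweighted mean of batch means (biased when batch sizes differ)."""
    return sum(batch_means) / len(batch_means)


# Hypothetical numbers: two full batches of 170 and a final batch of 20.
means = [0.5, 0.5, 2.0]
sizes = [170, 170, 20]
print(round(epoch_mean(means, sizes), 4))  # 0.5833
print(naive_mean(means))                   # 1.0
```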

I have tried setting `num_workers` to 0 and got the same result. The program hangs at `for x, y in train_set:` every time.

Any ideas why this is happening? Thanks in advance.

You could try setting `batch_size` to something really small and then running

```python
img, y = next(iter(train_set))
```

to see whether you can extract even a single batch.
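To tell a genuine hang apart from a very slow first batch (worker start-up and disk reads can take a while), one option is to request the first item in a background thread and give up after a deadline. A minimal stdlib sketch, assuming any iterable such as `train_set` can be passed in; `first_with_timeout` is a hypothetical helper, not part of PyTorch. `DataLoader` also accepts its own `timeout` argument, but that only applies when `num_workers > 0`.

```python
import threading
import time
from typing import Any, Iterable


def first_with_timeout(iterable: Iterable[Any], seconds: float) -> Any:
    """Return the first item of `iterable`, or raise TimeoutError after `seconds`."""
    result: dict = {}

    def worker() -> None:
        try:
            result["value"] = next(iter(iterable))
        except BaseException as exc:  # forward producer errors to the caller
            result["error"] = exc

    t = threading.Thread(target=worker, daemon=True)
    t.start()
    t.join(seconds)
    if t.is_alive():
        # The fetch is still blocked; the daemon thread is abandoned, not killed.
        raise TimeoutError(f"no item produced within {seconds} s")
    if "error" in result:
        raise result["error"]
    return result["value"]


def stuck():
    """Stand-in for a loader that never yields a batch."""
    time.sleep(3600)
    yield None


print(first_with_timeout([("img", "label")], 1.0))  # ('img', 'label')
try:
    first_with_timeout(stuck(), 0.5)
except TimeoutError as exc:
    print("timed out:", exc)
```

With the real loader this would be called as `img, y = first_with_timeout(train_set, 60)`, where the 60-second limit is an arbitrary choice.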

It might also be that your PyTorch installation has been altered somehow. Which version are you on? Have you tried reinstalling or upgrading?

Thanks for replying. Yesterday I ran the model on the CPU a couple of times and it worked fine. Then I switched back to the GPU and it worked as expected again. The problem seems to have gone away by itself, after persisting through several attempts and even a machine restart, so I don't know what caused it. For reference, I have a GTX 1080 and I'm running PyTorch 0.4. I can't reproduce the problem now.