Training approach

My training code looks overcomplicated.

I searched for a good training example, but code left over from versions before 0.4 adds a lot of noise to my search.
I created this example to explain what I mean:

import torch
import torch.nn as nn
import torch.optim as optim

# set the mini-batch size, optimizer and loss function
# (m is the model and dl the DataLoader, both defined elsewhere)
bs = 512
opt = optim.Adam(m.parameters(), 0.0001)
loss_fn = nn.NLLLoss()

# set the number of epochs
num_epochs = 1000

# get an iterator from the data loader
it = iter(dl)
# grab the first mini-batch
mb, yt = next(it)

# train the model for num_epochs
for epoch in range(num_epochs):
    # This works while the DataLoader still has full batches, but at some point
    # next() returns a batch with fewer than bs samples, and that is when
    # I get the exception
    if mb.shape[0] == bs:   # bs = 512
        # split the columns of the mini-batch into separate input tensors
        tup = torch.unbind(mb, dim=1)

        # forward pass to calculate the prediction
        y_hat = m(*tup)

        # loss evaluation
        loss = loss_fn(y_hat, yt)

        # backward pass and parameter update
        opt.zero_grad()
        loss.backward()
        opt.step()

        # fetch the next mini-batch, then restart the iterator
        # and take its first batch again
        mb, yt = next(it)
        it = iter(dl)
        mb, yt = next(it)

    if (epoch+1) % 50 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

This looks so complicated because I need to keep track of how much data is left in the DataLoader: at some point next() returns fewer than bs examples.

OK, I could set shuffle=True on the DataLoader. In that case I could use mb, yt = next(iter(dl)) every time.
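As a side note, shuffle=True changes the order of the samples but not the size of the last batch. The DataLoader argument that discards the incomplete final batch is drop_last=True. A minimal sketch, assuming a synthetic dataset of 1000 samples (an illustrative size, not from the original post) and bs = 512:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

bs = 512
# Hypothetical dataset: 1000 samples with 4 features each and integer labels.
ds = TensorDataset(torch.randn(1000, 4), torch.randint(0, 3, (1000,)))

# Default behaviour: the last batch holds the remainder, 1000 - 512 = 488 samples.
sizes_default = [xb.size(0) for xb, yb in DataLoader(ds, batch_size=bs, shuffle=True)]

# drop_last=True discards the incomplete final batch, so every batch has exactly bs samples.
sizes_dropped = [xb.size(0) for xb, yb in DataLoader(ds, batch_size=bs, shuffle=True, drop_last=True)]

print(sizes_default)   # [512, 488]
print(sizes_dropped)   # [512]
```

With shuffle=True as well, the dropped remainder is a different random subset in each epoch, so no sample is permanently left out.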

Any feedback on my training approach would be helpful at this point.


For me, “epoch” has a slightly different meaning: one epoch is one full pass over the dataset, so the number of epochs is how many times you go through the whole dataset. With this definition, I would do:

n_batches_seen = 0   # counts mini-batches across all epochs
for epoch in range(num_epochs):
    for mb, yt in dl:
        # Not needed if your code handles any batch size
        if mb.size(0) != bs:
            continue

        # Your training code

        n_batches_seen += 1
        if n_batches_seen % 50 == 0:
            # Some printing
            print('Batches seen:', n_batches_seen)
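For illustration, here is a self-contained version of the loop above. The model, data and sizes are hypothetical stand-ins (a single linear layer followed by LogSoftmax, so that it pairs with nn.NLLLoss as in the question), and the training step is the standard zero_grad / backward / step sequence:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical data: 1000 samples, 4 features, 3 classes.
ds = TensorDataset(torch.randn(1000, 4), torch.randint(0, 3, (1000,)))
bs = 512
dl = DataLoader(ds, batch_size=bs, shuffle=True)

# Stand-in model: NLLLoss expects log-probabilities, hence LogSoftmax.
m = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
opt = optim.Adam(m.parameters(), 0.0001)
loss_fn = nn.NLLLoss()

num_epochs = 100
n_batches_seen = 0
for epoch in range(num_epochs):
    for mb, yt in dl:            # one full pass over dl = one epoch
        y_hat = m(mb)            # forward pass
        loss = loss_fn(y_hat, yt)

        opt.zero_grad()          # clear gradients from the previous step
        loss.backward()          # compute new gradients
        opt.step()               # update the parameters

        n_batches_seen += 1
        if n_batches_seen % 50 == 0:
            print('Batches seen: {}, Loss: {:.4f}'.format(n_batches_seen, loss.item()))
```

With 1000 samples and bs = 512 there are 2 batches per epoch, so 100 epochs run 200 optimizer steps and print 4 times.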

Thanks. A simple for loop makes my training code simpler. In my case, I rewrote the code like this:

for epoch in range(num_epochs):
    for mb, yt in dl:
        print(" epoch...")
        # training code

I learned that I don't necessarily need to use iter and next.
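A for loop over the DataLoader calls iter() once and then next() repeatedly until StopIteration is raised, which is why the explicit calls are not needed. A small sketch comparing the two, assuming a toy TensorDataset of 10 samples and batch_size=4 (illustrative values):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
dl = DataLoader(ds, batch_size=4)  # shuffle=False, so both passes see the same order

# Manual version: iter() once, then next() until the DataLoader is exhausted.
manual = []
it = iter(dl)
while True:
    try:
        mb, yt = next(it)
    except StopIteration:   # raised once every batch has been returned
        break
    manual.append(yt.tolist())

# Idiomatic version: the for loop handles iter() and StopIteration itself.
looped = [yt.tolist() for mb, yt in dl]

print(manual)            # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(manual == looped)  # True
```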

One epoch means running through all the samples, and that was a problem in my original code.
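One way to check this definition: over a single epoch, the batch sizes add up to len(dataset) and, even with shuffle=True, every sample appears exactly once. A minimal sketch with a hypothetical 1000-sample dataset whose labels are simply the sample indices, so each sample can be tracked:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

n = 1000
# Hypothetical dataset: the "label" of each sample is its own index.
ds = TensorDataset(torch.randn(n, 4), torch.arange(n))
dl = DataLoader(ds, batch_size=512, shuffle=True)

seen = []
for mb, idx in dl:          # one epoch = one full pass over dl
    seen.extend(idx.tolist())

print(len(seen) == len(ds))            # True: 512 + 488 samples
print(sorted(seen) == list(range(n)))  # True: each sample exactly once
```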

Great!
By the way, unless your code specifically cannot handle a batch whose size differs from the full one, you don’t need the if mb.size(0) != bs: block. PyTorch functions all work with a smaller batch. You can also move this print outside the inner loop if the goal is just to print once per epoch!
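To illustrate the point about batch sizes: layers such as nn.Linear accept any batch size, so a smaller final batch goes through the same code as the full ones. A sketch under a hypothetical setup (1000 samples and bs = 512, so the last batch has 488 samples, with a stand-in model), with the print moved out of the inner loop so that it runs once per epoch and reports the average loss:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
ds = TensorDataset(torch.randn(1000, 4), torch.randint(0, 3, (1000,)))
dl = DataLoader(ds, batch_size=512, shuffle=True)   # last batch: 488 samples

m = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))   # hypothetical stand-in model
opt = optim.Adam(m.parameters(), 0.0001)
loss_fn = nn.NLLLoss()

num_epochs = 3
epoch_losses = []
for epoch in range(num_epochs):
    total_loss, n_samples = 0.0, 0
    for mb, yt in dl:                       # no batch-size check needed
        loss = loss_fn(m(mb), yt)           # works for 512 and for 488 samples
        opt.zero_grad()
        loss.backward()
        opt.step()
        # NLLLoss averages over the batch by default, so weight by batch size
        total_loss += loss.item() * mb.size(0)
        n_samples += mb.size(0)
    epoch_losses.append(total_loss / n_samples)
    # Printed once per epoch, outside the inner loop.
    print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, epoch_losses[-1]))
```

Weighting each batch loss by its size keeps the epoch average correct even though the last batch is smaller.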