Train for number of iterations rather than number of epochs

In a typical PyTorch training loop, we have an `enumerate(dataloader)` inside the epoch loop.

How can I change this to train for a number of iterations rather than a number of epochs?

Something like this :slight_smile:

tot_iter = 0
max_iterations = 100
stop = False  # must be initialized before the loop, or the `if stop` check raises NameError

while True:   # loop over epochs indefinitely; we break out once max_iterations is reached
    for batch_index, data in enumerate(data_loader, 0):
        # ...training step, backprop, etc...
        tot_iter += 1
        if tot_iter >= max_iterations:
            stop = True
            break
    if stop:
        break
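An alternative to the flag-and-double-break pattern is to cycle the loader and slice off exactly the number of iterations you want. A minimal sketch, using a plain list of batches as a stand-in for a real `DataLoader` (the counting logic is the same; the training step is elided):

```python
from itertools import cycle, islice

data_loader = [f"batch{i}" for i in range(4)]  # stand-in for a DataLoader
max_iterations = 10

seen = []
for data in islice(cycle(data_loader), max_iterations):
    # ...training step, backprop, etc. would go here...
    seen.append(data)

print(len(seen))  # 10 iterations, regardless of the 4-batch "epoch" size
```

One caveat with a real `DataLoader`: `itertools.cycle` reuses the same iteration order, so per-epoch shuffling does not happen. If you rely on `shuffle=True`, prefer the nested-loop version, which creates a fresh iterator each epoch.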

Yeah that works for sure. Thanks.