Model Training becomes very slow if I add one outer for loop

I am trying to run my model several times so I can estimate a confidence interval for its predictions. However, training becomes very slow when I do this. Here is a sample of the code:

    for i in range(opt.ite + 1):
        model = getattr(models, opt.model)()
        if opt.load_model_path:
            model.load(opt.load_model_path)
        model.to(device)

        model.apply(weight_init)

        # optimizer
        op = Optim(model.parameters(), opt)
        optimizer = op._makeOptimizer()

        previous_err = torch.tensor(10000)
        best_epoch = 0
  
        for epoch in range(opt.max_epoch + 1):
          
            # load the data
            for data in zip(datI_loader, datB_loader):
                # train part                   
                loss.backward()
                optimizer.step()
       
            if epoch % 100 == 0:
                test_err = val(model, grid, sol)
                if test_err < previous_err:
                    previous_err = test_err
                    best_epoch = epoch
        test_meter.add(previous_err.to('cpu'))
        epoch_meter.add(best_epoch)

For example, if I set opt.ite to something like 2 or 3, training becomes extremely slow and can take 5 or 6 hours. However, if I remove the outer for loop, like this:

        model = getattr(models, opt.model)()
        if opt.load_model_path:
            model.load(opt.load_model_path)
        model.to(device)

        model.apply(weight_init)

        # optimizer
        op = Optim(model.parameters(), opt)
        optimizer = op._makeOptimizer()

        previous_err = torch.tensor(10000)
        best_epoch = 0
  
        for epoch in range(opt.max_epoch + 1):
          
            # load the data
            for data in zip(datI_loader, datB_loader):
                # train part                   
                loss.backward()
                optimizer.step()
       
            if epoch % 100 == 0:
                test_err = val(model, grid, sol)
                if test_err < previous_err:
                    previous_err = test_err
                    best_epoch = epoch

it only takes about 20 minutes to finish.

Any ideas about this problem?

I don’t know what val is doing exactly, but it seems you are appending its output to some object.
If you are not running this code (inside val) in a with torch.no_grad() context, the returned tensor could keep the entire computation graph alive, which increases memory usage and leads to slowdowns.
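A minimal sketch of that fix, assuming val compares the model's output against a reference solution (the model, grid, and sol here are placeholders, not your actual definitions):

```python
import torch

@torch.no_grad()  # disable autograd tracking for the whole validation call
def val(model, grid, sol):
    # Hypothetical validation: mean squared error against the reference solution.
    pred = model(grid)
    return torch.mean((pred - sol) ** 2)

# Toy model and data just to demonstrate the effect.
model = torch.nn.Linear(2, 1)
grid = torch.randn(8, 2)
sol = torch.randn(8, 1)

test_err = val(model, grid, sol)
# Because of no_grad, the returned tensor carries no computation graph,
# so storing it (e.g. in test_meter) cannot keep old graphs alive.
print(test_err.requires_grad)
```

Alternatively, if you cannot change val, call `test_err = test_err.detach()` before storing it, which drops the graph reference at the point of storage instead.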