The iteration count and dataset length come out wrong, and data_time_interval is too large. How can I solve these problems?

I have 4 GPUs and want to use PyTorch DDP for training. My test dataset has 80 training samples and 20 validation samples, and the batch size is set to 16. During training I want to print the iteration number i and the number of steps per epoch, which I store as args.step_per_epoch = len(train_data). Below is the code for train and evaluate (a quick arithmetic check of the expected steps per epoch follows the code).

    import math
    import time

    import numpy as np
    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    # args, device, epochlist, psnrlist, get_learning_rate and VimeoDataset
    # are defined elsewhere in my script.

    def train(model, local_rank):
        model_path = '/output/saved_model'
        step = 0
        nr_eval = 0
        dataset = VimeoDataset(mode='train')
        sampler = DistributedSampler(dataset)  # splits the dataset across the 4 ranks
        train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True, drop_last=True, sampler=sampler)
        args.step_per_epoch = len(train_data)  # length of this rank's loader
        dataset_val = VimeoDataset(mode='val')
        val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=4)
        evaluate(model, val_data, nr_eval, local_rank)
        epochlist.append(0)
        model.save_model(model_path, local_rank)
        print('training...')
        time_stamp = time.time()
        for epoch in range(args.epoch):
            sampler.set_epoch(epoch)
            for i, data in enumerate(train_data):
                data_time_interval = time.time() - time_stamp  # time spent waiting for the next batch
                time_stamp = time.time()
                data_gpu, flow_gt = data
                data_gpu = data_gpu.to(device, non_blocking=True) / 255.
                flow_gt = flow_gt.to(device, non_blocking=True)
                imgs = data_gpu[:, :6]
                gt = data_gpu[:, 6:9]
                # cosine annealing factor: decays from 1 to 0 over the full run
                mul = np.cos(step / (args.epoch * args.step_per_epoch) * math.pi) * 0.5 + 0.5
                learning_rate = get_learning_rate(step)
                pred, merged_img, flow, loss_LPIPS, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, learning_rate, mul, True, flow_gt)
                train_time_interval = time.time() - time_stamp  # forward/backward/optimizer time
                time_stamp = time.time()
                if local_rank == 0:
                    print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_LPIPS:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss_LPIPS))
                step += 1
            nr_eval += 1
            if nr_eval % 5 == 0:
                evaluate(model, val_data, step, local_rank)
                epochlist.append(nr_eval)
            model.save_model(model_path, local_rank)    
            dist.barrier()

    def evaluate(model, val_data, nr_eval, local_rank):
        psnr_list = []
        time_stamp = time.time()
        for i, data in enumerate(val_data):
            data_gpu, flow_gt = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            flow_gt = flow_gt.to(device, non_blocking=True)
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            with torch.no_grad():
                pred, merged_img, flow, loss_LPIPS, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, training=False)
            for j in range(gt.shape[0]):
                # PSNR = -10 * log10(MSE), with images scaled to [0, 1]
                mse = torch.mean((gt[j] - pred[j]) ** 2).item()
                psnr = -10 * math.log10(mse)
                psnr_list.append(psnr)
        
        eval_time_interval = time.time() - time_stamp
        if local_rank == 0:
            print('eval time: {}'.format(eval_time_interval)) 
            print('mean psnr: {}'.format(np.mean(psnr_list)))
            psnrlist.append(np.mean(psnr_list))
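
To make the numbers concrete, here is a quick arithmetic check (a standalone sketch, not part of the training script; the values 80, 4 and 16 come from my setup above). I naively expected len(train_data) to be 80 / 16 = 5, but I suspect DistributedSampler first splits the dataset across the ranks:

    import math

    dataset_len = 80  # training samples
    world_size = 4    # GPUs / DDP ranks
    batch_size = 16

    # What I expected the loader length to be (no sharding):
    print(dataset_len // batch_size)  # 5

    # What I suspect actually happens: DistributedSampler gives each rank
    # ceil(80 / 4) = 20 samples, and drop_last=True keeps only full batches,
    # so len(train_data) on each rank is floor(20 / 16) = 1.
    samples_per_rank = math.ceil(dataset_len / world_size)
    print(samples_per_rank // batch_size)  # 1

Is that the right explanation for the 0/1 in the log below, and if so, what is the correct way to get the real steps per epoch?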

The following is part of the training log. You can see that values such as i and args.step_per_epoch are wrong: the progress counter is always 0/1. What is the cause, and how should I modify the code? In addition, the first of the two time values printed is data_time_interval, and it is consistently larger than train_time_interval. How can I reduce data_time_interval to improve training efficiency? (A sketch of the loader tweaks I have in mind follows the log.)

    eval time: 8.030341386795044
    mean psnr: 24.4023611466035
    training...
    epoch:0 0/1 time:0.30+2.96 loss_LPIPS:3.4511e-01
    epoch:1 0/1 time:0.93+0.41 loss_LPIPS:3.3067e-01
    epoch:2 0/1 time:1.29+0.36 loss_LPIPS:3.3386e-01
    epoch:3 0/1 time:4.61+0.37 loss_LPIPS:3.2475e-01
    epoch:4 0/1 time:5.26+0.36 loss_LPIPS:3.1935e-01
    eval time: 0.9092228412628174
    mean psnr: 24.403575688421018
    epoch:5 0/1 time:3.94+0.36 loss_LPIPS:3.5425e-01
    epoch:6 0/1 time:4.75+0.36 loss_LPIPS:3.5130e-01
    epoch:7 0/1 time:3.60+0.36 loss_LPIPS:3.2492e-01
    epoch:8 0/1 time:1.40+0.37 loss_LPIPS:3.4967e-01
    epoch:9 0/1 time:1.12+0.37 loss_LPIPS:3.4065e-01
    ...
    ...
    epoch:90 0/1 time:7.11+0.38 loss_LPIPS:3.1828e-01
    epoch:91 0/1 time:2.66+0.37 loss_LPIPS:2.7712e-01
    epoch:92 0/1 time:6.57+0.36 loss_LPIPS:3.0946e-01
    epoch:93 0/1 time:5.59+0.36 loss_LPIPS:2.5663e-01
    epoch:94 0/1 time:0.76+0.36 loss_LPIPS:2.8386e-01
    eval time: 0.8664124011993408
    mean psnr: 24.744649141015156
    epoch:95 0/1 time:1.89+0.35 loss_LPIPS:2.8509e-01
    epoch:96 0/1 time:2.24+0.36 loss_LPIPS:3.0353e-01
    epoch:97 0/1 time:1.49+0.37 loss_LPIPS:3.0354e-01
    epoch:98 0/1 time:1.34+0.36 loss_LPIPS:2.9313e-01
    epoch:99 0/1 time:1.27+0.36 loss_LPIPS:2.9234e-01
    eval time: 0.8701093196868896
    mean psnr: 24.784283493153136
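
Regarding data_time_interval: would tuning the loader along these lines help? This is only a sketch; persistent_workers and prefetch_factor are standard torch.utils.data.DataLoader arguments, and num_workers=8 is just a guess for my machine.

    train_data = DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=8,            # more loader processes than the current 4
        pin_memory=True,
        drop_last=True,
        persistent_workers=True,  # keep workers alive between epochs
        prefetch_factor=4,        # batches each worker prepares in advance
    )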