Training gets "stuck" every three batches (PyTorch 2.0)

I am training ResNet-50 on ImageNet on an NVIDIA A40. The training slows down every three batches and then recovers to normal speed. Why would this happen?


  1. My DataLoader and model are set up like this:
    loader = DataLoader(dataset, batchsize, shuffle = True, num_workers = 4, prefetch_factor = 2, drop_last = True)
    model = torch.compile(resnet50, mode = 'max-autotune')
  2. I used amp.autocast() to accelerate my training.
  3. The loss function is just nn.CrossEntropyLoss.
  4. I set cudnn.benchmark = True.

The DataLoader is set to use 4 workers with a prefetch factor of 2. It’s possible that data loading is not keeping up with training, so the GPU has to wait for new data every few batches. You can try the following:

  • Increasing the num_workers to a higher value (e.g., 8 or 16) to see if it improves the situation.
  • Increasing the prefetch_factor so that more samples are preloaded.
  • Monitoring the CPU and disk usage to see if there’s a bottleneck in data loading.
  • You can try disabling amp.autocast() to see if the issue still persists. If the issue disappears, you might need to fine-tune the mixed precision training configuration.
  • You have set cudnn.benchmark = True, which lets cuDNN benchmark the available algorithms and pick the fastest one for the current input configuration. In some cases, this can cause fluctuations in performance. You can try setting cudnn.benchmark = False and cudnn.deterministic = True to see if this improves the situation.
  • Monitor the GPU utilization using a tool like nvidia-smi to see if there are any fluctuations in the GPU usage that correlate with the slowdowns.
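One way to check whether the GPU is waiting on data is to split each iteration into time spent blocked on the loader and time spent in the training step. The following is only a minimal sketch: time_batches, step, and sync are hypothetical names that are not part of the original code, and loader can be any iterable, including a DataLoader.

```python
import time


def time_batches(loader, step, n_batches=50, sync=None):
    """Return (waits, steps): seconds blocked on the loader vs. seconds in step()."""
    waits, steps = [], []
    it = iter(loader)
    for _ in range(n_batches):
        t0 = time.perf_counter()
        try:
            batch = next(it)  # blocks here if the workers cannot keep up
        except StopIteration:
            break
        t1 = time.perf_counter()
        step(batch)  # forward/backward/optimizer step supplied by the caller
        if sync is not None:
            sync()  # e.g. torch.cuda.synchronize, so queued GPU work is counted
        t2 = time.perf_counter()
        waits.append(t1 - t0)
        steps.append(t2 - t1)
    return waits, steps
```

On a GPU, passing sync=torch.cuda.synchronize keeps asynchronous kernel launches from making the step look faster than it is. If the wait times spike every few batches while the step times stay flat, the loader is the likely bottleneck.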

Thanks! The training speed did increase after I tried some of these strategies. However, there’s another weird thing: the speed was good for the first several dozen batches (about 50~100), but after that it decreased and the problem came back. I have no idea why this happened.

Could you explain what “fine-tune the mixed precision training configuration” means here, please?

What exactly did you change?

I meant trying different data types, for example a lower-precision type like float16 or a higher-precision type like float32.

I set cudnn.deterministic = True and increased num_workers, and the speed increased. However, the problem still exists: the speed is good for the first several dozen batches and then decreases. It seems that at roughly every power of 10 (around the 10th, 100th, 1000th, … batch) within an epoch the speed drops by half. Once the speed has dropped far enough, the training gets “stuck” every three batches.

This might be a problem related to your data. I would suggest you inspect your data thoroughly. You can even try training with a smaller subset.

I think your data loading might be running into a bottleneck, and you could try to profile the DataLoader e.g. as used here.
If you see the slowdown in the DataLoader, you could profile the data loading pipeline standalone to check what exactly is slowing it down.
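A standalone check could iterate the loader with no model in the loop and compare throughput across worker counts. This is a rough sketch under assumptions: loader_throughput and make_loader are hypothetical names, and make_loader is a factory you supply that builds a loader for a given number of workers.

```python
import time

_DONE = object()  # sentinel marking an exhausted loader


def loader_throughput(make_loader, worker_counts, n_batches=100, warmup=5):
    """Return {num_workers: batches per second} with no model in the loop."""
    results = {}
    for workers in worker_counts:
        it = iter(make_loader(workers))
        # Skip the first few batches so worker start-up cost is not measured.
        for _ in range(warmup):
            next(it, _DONE)
        seen = 0
        start = time.perf_counter()
        for _ in range(n_batches):
            if next(it, _DONE) is _DONE:
                break
            seen += 1
        elapsed = time.perf_counter() - start
        results[workers] = seen / elapsed if elapsed > 0 else float("inf")
    return results
```

With the setup from this thread, make_loader could be something like lambda w: DataLoader(dataset, batchsize, shuffle=True, num_workers=w, drop_last=True), called with worker counts such as [0, 4, 8, 12]. If throughput keeps rising with more workers, the CPU-side decoding and augmentation is the limit; if it flattens early, disk reads are a more likely suspect.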

Thanks, you are right. The problem is in the DataLoader: num_workers = 4 with my prefetch_factor is not enough. My machine has 12 workers available, and when I used all of them the speed increased a lot. However, there’s a new issue: if I use model = torch.compile(model, mode = 'max-autotune') with all 12 workers, I get a lot of errors like

  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_inductor/", line 328, in cudagraphify_impl
    static_outputs = model(list(static_inputs))
  File "/tmp/torchinductor_root/vi/", line 2684, in call, buf37, primals_171, primals_172, buf42, buf40, buf41, buf43, 256, 1605632, grid=grid(256), stream=stream0)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_inductor/triton_ops/", line 190, in run
    result = launcher(
  File "<string>", line 6, in launcher
RuntimeError: Triton Error [CUDA]: operation failed due to a previous error during capture

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/syc_cache/", line 71, in <module>
  File "/mnt/syc_cache/", line 20, in main
    train(args, p, epoch)
  File "/mnt/syc_cache/", line 38, in train
    loss = forward(images, labels, args, p)
  File "/mnt/syc_cache/", line 56, in forward
    output = p['model'](images)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_dynamo/", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_dynamo/", line 209, in _fn
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torchvision/models/", line 284, in forward
    def forward(self, x: Tensor) -> Tensor:
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_dynamo/", line 209, in _fn
    return fn(*args, **kwargs)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 2819, in forward
    return compiled_fn(full_args)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 1222, in g
    return f(*args)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 2386, in debug_compiled_function
    return compiled_function(*args)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 1898, in runtime_wrapper
    all_outs = call_func_with_args(
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 1247, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 1222, in g
    return f(*args)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/autograd/", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 2151, in forward
    fw_outs = call_func_with_args(
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_functorch/", line 1247, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_inductor/", line 248, in run
    return model(new_inputs)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_inductor/", line 265, in run
    compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/_inductor/", line 327, in cudagraphify_impl
    with torch.cuda.graph(graph, stream=stream):
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/cuda/", line 173, in __exit__
  File "/root/miniconda3/envs/myconda/lib/python3.10/site-packages/torch/cuda/", line 79, in capture_end
RuntimeError: CUDA error: operation failed due to a previous error during capture
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

If I just use model = torch.compile(model) with all 12 workers, there are no errors. Does max-autotune require extra workers?

Does the code work with torch.compile(model, mode="max-autotune") if num_workers is set to 0?
Based on the error message I would guess the failure is unrelated to the DataLoader workers and something fails inside torch.compile. However, the actual error is not displayed and only re-raised:

RuntimeError: Triton Error [CUDA]: operation failed due to a previous error during capture

The code runs with torch.compile(model, mode="max-autotune") with both num_workers = 4, prefetch_factor = 4 and num_workers = 0, prefetch_factor = None. However, it fails with num_workers = 8, prefetch_factor = 8 and higher.

The first traceback in the errors shown above only reports
RuntimeError: Triton Error [CUDA]: operation failed due to a previous error during capture. The actual error is not displayed, so I have no idea what causes this.
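Since the traceback fails inside cudagraphify_impl, one thing that might surface the underlying error is to keep autotuning but turn off CUDA graph capture, together with the CUDA_LAUNCH_BLOCKING=1 setting the error message itself suggests. This is a sketch and not a confirmed fix: compile_without_cudagraphs is a hypothetical helper, and the option keys are the ones listed in the torch.compile documentation for the Inductor backend.

```python
import os

# Suggested by the error message: make kernel launches synchronous so that an
# error is reported at the failing call. Set it before any CUDA work starts.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

# Keep autotuning but skip CUDA graph capture, which is where the traceback fails.
NO_CUDAGRAPH_OPTIONS = {"max_autotune": True, "triton.cudagraphs": False}


def compile_without_cudagraphs(model):
    """Compile with autotuning but without CUDA graphs (debugging aid only)."""
    import torch  # imported here so the environment variable above is already set

    # torch.compile accepts either mode= or options=, not both.
    return torch.compile(model, options=NO_CUDAGRAPH_OPTIONS)
```

If the run then fails with a more specific message (for example an out-of-memory error), that is probably the real cause hidden behind the capture error. If it runs cleanly, the problem is likely specific to graph capture.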