Trouble with concurrency and IterableDataset

Background

I’m trying to fine-tune a fork of the Segment-Anything model released by Meta on a custom medical imaging dataset. The problem is as follows:

  • Training data is stored as 304 npz files on disk. Each file is about 1.4 GB, so the total is roughly 430 GB, which is too much to hold in RAM.
  • The solution I devised is a custom IterableDataset that maintains a buffer. Whenever the buffer drops below batch_size, it reads another file from disk to refill it; otherwise it keeps serving batches from the buffer.
  • I’m running this with multiple DataLoader workers (multiprocessing) to hide the bottleneck of the load times.
  • This results in deadlock-like behavior as I increase the number of workers. With num_workers=4 I still get very long read times for the npz files.

Minimal Reproducible Example

Here is the code that reproduces this behavior:

import os
import torch
import time

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import DataLoader, IterableDataset
from datetime import datetime
from collections import deque

class NpzIterableDataset(IterableDataset):
    def __init__(self, data_root, batch_size):
        super().__init__()
        self.data_root = data_root
        self.batch_size = batch_size
        self.npz_files = None
        self.embedding_buffer = deque()
        self.gt_buffer = deque()

    def _load_files(self, file_list):
        self.npz_files = file_list

    def _load(self):
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else -999
        if not self.npz_files:
            return False

        start_time = time.time()
        filename = self.npz_files.pop()
        print(f"[Worker {worker_id}]: Npz files left: {len(self.npz_files)}")

        try:
            with np.load(os.path.join(self.data_root, filename)) as npz_data:
                embeddings_t = torch.from_numpy(npz_data["img_embeddings"]).float()
                gts_t = torch.from_numpy(npz_data["gts"][:, None, :, :]).long()
                self.embedding_buffer.extend(embeddings_t)
                self.gt_buffer.extend(gts_t)
        except Exception as e:
            # Surface load failures instead of silently swallowing them
            print(f"[Worker {worker_id}]: Failed to load {filename}: {e}")
            return False

        end_time = time.time()

        print(f"[Worker {worker_id}]: Load time: {end_time - start_time:.2f}s")
        return True

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else -999
        while True:
            delay = worker_id * 0.1
            time.sleep(delay)
            if len(self.embedding_buffer) < self.batch_size:
                if not self._load():
                    break  # Stop iteration

            for _ in range(self.batch_size):
                yield self.embedding_buffer.pop(), self.gt_buffer.pop()

def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process

    # Configure the dataset to only process a split workload 
    file_list = os.listdir(dataset.data_root)
    np.random.shuffle(file_list)
    sublists = [deque(sublist) for sublist in np.array_split(file_list, worker_info.num_workers)]
    dataset._load_files(sublists[worker_id])

TASK_PATH = "/home/khans24/medsam_finetuning/finetune_task"
NPZ_TR_PATH = "load/npz_files/train"
BATCH_SIZE = 64
NUM_WORKERS = 4
DEVICE = "cuda:1"

if __name__ == "__main__":
    for idx in range(2):
        train_dataset = NpzIterableDataset(os.path.join(TASK_PATH, NPZ_TR_PATH), batch_size=BATCH_SIZE)
        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, 
                                      worker_init_fn=worker_init_fn)

        train_loop_length = len(os.listdir(os.path.join(TASK_PATH, NPZ_TR_PATH)))  # Number of npz files
        train_loop_length = (train_loop_length * 383) // BATCH_SIZE  # Number of batches

        for step, (image_embedding, gt) in enumerate(tqdm(train_dataloader, total=train_loop_length)):
            image_embedding = image_embedding.to(DEVICE)  # .to() is not in-place, so keep the result
            gt = gt.to(DEVICE)
            avg_i = torch.mean(image_embedding)

and here is the output I am logging:

(sam) ~/medsam_finetuning$ python minimum_repr.py
  0%|                                                                                                 | 0/1819 [00:00<?, ?it/s][Worker 0]: Npz files left: 75
[Worker 1]: Npz files left: 75
[Worker 2]: Npz files left: 75
[Worker 3]: Npz files left: 75
[Worker 3]: Load time: 8.71s
[Worker 0]: Load time: 9.05s
[Worker 2]: Load time: 9.02s
  0%|                                                                                       | 1/1819 [00:19<9:49:08, 19.44s/it][Worker 1]: Load time: 85.61s
  1%|▌                                                                                       | 12/1819 [01:28<27:48,  1.08it/s][Worker 0]: Npz files left: 74
  1%|▋                                                                                       | 13/1819 [01:28<21:25,  1.40it/s][Worker 1]: Npz files left: 74
  1%|▋                                                                                       | 14/1819 [01:29<17:06,  1.76it/s][Worker 2]: Npz files left: 74
  1%|▊                                                                                       | 16/1819 [01:29<11:50,  2.54it/s][Worker 3]: Npz files left: 74
  1%|▉                                                                                       | 20/1819 [01:30<08:14,  3.64it/s][Worker 0]: Load time: 9.15s
  1%|▉                                                                                     | 21/1819 [01:38<1:16:11,  2.54s/it][Worker 2]: Load time: 9.09s
[Worker 3]: Load time: 8.81s
[Worker 1]: Load time: 108.77s
  2%|█▋                                                                                      | 36/1819 [03:21<12:56,  2.30it/s][Worker 0]: Npz files left: 73
  2%|█▊                                                                                      | 37/1819 [03:21<11:07,  2.67it/s][Worker 1]: Npz files left: 73
  2%|█▊                                                                                      | 38/1819 [03:21<09:58,  2.98it/s][Worker 2]: Npz files left: 73
  2%|█▉                                                                                      | 40/1819 [03:22<08:07,  3.65it/s][Worker 3]: Npz files left: 73
  2%|██▏                                                                                     | 44/1819 [03:23<06:56,  4.26it/s][Worker 0]: Load time: 8.94s
  2%|██▏                                                                                   | 45/1819 [03:30<1:13:21,  2.48s/it][Worker 2]: Load time: 9.34s
[Worker 1]: Load time: 247.14s
  3%|██▏                                                                                  | 47/1819 [07:29<25:17:00, 51.37s/it][Worker 3]: Load time: 248.15s

Furthermore, if I kill the program after it has hung for longer than 120 seconds, this is where the main process is blocked:

Traceback (most recent call last):
  File "/home/khans24/medsam_finetuning/train_debug.py", line 175, in <module>
    for step, (image_embedding, gt) in enumerate(train_dataloader):
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1295, in _get_data
    success, data = self._try_get_data()
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/connection.py", line 258, in poll
    return self._poll(timeout)
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/connection.py", line 425, in _poll
    r = wait([self], timeout)
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/connection.py", line 932, in wait
    ready = selector.select(timeout)
  File "/home/khans24/miniconda3/envs/sam/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)

Other Checks

To make sure this was not an IOPS issue, here is the output of iostat -dx 1 while the program is running inside a tmux session:


Device:         rrqm/s   wrqm/s     r/s     w/s    rkB/s    wkB/s avgrq-sz avgqu-sz   await r_await w_await  svctm  %util
sda               0.00     0.06    3.96    0.71    68.34    16.72    36.40     0.00    0.09    0.06    0.24   0.03   0.02
dm-0              0.00     0.00    3.97    0.78    68.34    16.72    35.86     0.00    0.09    0.06    0.25   0.03   0.02
dm-1              0.00     0.00    0.00    0.00     0.00     0.00    48.19     0.00    0.15    0.15    0.00   0.08   0.00

Device:         rrqm/s   wrqm/s     r/s     w/s    rkB/s    wkB/s avgrq-sz avgqu-sz   await r_await w_await  svctm  %util
sda               0.00     0.00    0.00    0.00     0.00     0.00     0.00     0.00    0.00    0.00    0.00   0.00   0.00
dm-0              0.00     0.00    0.00    0.00     0.00     0.00     0.00     0.00    0.00    0.00    0.00   0.00   0.00
dm-1              0.00     0.00    0.00    0.00     0.00     0.00     0.00     0.00    0.00    0.00    0.00   0.00   0.00

Device:         rrqm/s   wrqm/s     r/s     w/s    rkB/s    wkB/s avgrq-sz avgqu-sz   await r_await w_await  svctm  %util
sda               0.00     0.00    0.00    0.00     0.00     0.00     0.00     0.00    0.00    0.00    0.00   0.00   0.00
dm-0              0.00     0.00    0.00    0.00     0.00     0.00     0.00     0.00    0.00    0.00    0.00   0.00   0.00
dm-1              0.00     0.00    0.00    0.00     0.00     0.00     0.00     0.00    0.00    0.00    0.00   0.00   0.00

So I don’t think this is an issue of too many read/writes from disk.

Problem/Questions

  • Based on this, I suspect each worker is blocking while waiting for its npz file to become readable. But given my implementation of worker_init_fn, the list of npz files handed to each copy of NpzIterableDataset should be a disjoint subset. So what shared resource is each worker process blocking on?
  • I’m open to suggestions for a different approach to my underlying problem of streaming data in from disk.
  • I’m also open to multiprocess debugging tools (see the sketch after this list for the kind of thing I mean). I’ve tried PuDB, but it tends to be unresponsive and crashes whenever I use its command-line feature.
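For reference, the kind of lightweight inspection I have in mind is something like periodically dumping each worker’s Python stack with the standard-library faulthandler module. This is only a rough sketch: the 120-second interval is arbitrary, and debug_worker_init_fn simply wraps the worker_init_fn from the repro above.

import faulthandler

def debug_worker_init_fn(worker_id):
    # Each DataLoader worker is its own process, so this arms a watchdog thread
    # in that worker which dumps its Python stack to stderr every 120 seconds,
    # showing where it is stuck without attaching a debugger.
    faulthandler.dump_traceback_later(120, repeat=True)
    worker_init_fn(worker_id)  # original init from the repro above

Passing worker_init_fn=debug_worker_init_fn to the DataLoader (and calling faulthandler.dump_traceback_later in the main process as well) would at least narrow down which call each process is sitting in.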

Misc

  • There is another relevant thread on this forum where the poster seems to be having the same issue parallelizing an IterableDataset. I don’t think that solution works for me, though, since my dataset is very large.

Lastly, thank you for reading this post and for your thoughts and suggestions!

Here is a minimal example that I think gets at the heart of the problem: loading npz files in a multiprocessing context:

import os
import time

import numpy as np
import multiprocessing as mp

def load_file(filepath):
    start_path = "/home/khans24/beegfs/medsam_finetuning/npz_files/train"
    start_time = time.time()
    with np.load(os.path.join(start_path, filepath)) as npz:
        embedding = npz['img_embeddings']
        gt = npz['gts']
    end_time = time.time()
    print(f"LOADED FILE: {filepath} in {end_time - start_time:.2f} seconds")
    return embedding, gt

if __name__ == "__main__":
    start_path = '/home/khans24/beegfs/medsam_finetuning/npz_files/train'
    pp = mp.Pool(4)
    npz_files = os.listdir(start_path)
    results = pp.map(load_file, npz_files[:4])
    pp.close()
    pp.join()

    print(results)

Running the above with 2 workers loads the files in about 13 seconds each. Running it with 4 workers loads the 4 files in about 30 seconds each, but the program then hangs and never reaches the final print statement.

Are you running in debug mode?

No, this is just from running the script at the terminal.

The main issue I see is that each npz file is around 1.5 GB, so you’re heavily IO-limited. If you’re not already storing the data on an SSD or NVMe drive, I would move it there.

I’m not fully sure how you’re planning to use this data, but you could also use mmap_mode in np.load. See here for a guide on how to use it with a PyTorch dataset.
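As a rough sketch of what I mean (this assumes you first export the arrays as plain .npy files with np.save, since as far as I know mmap_mode only applies to .npy files and is ignored for .npz archives; the paths and class below are just placeholders):

import numpy as np
import torch
from torch.utils.data import Dataset

class MmapNpyDataset(Dataset):
    def __init__(self, emb_path, gt_path):
        # mmap_mode='r' keeps the arrays on disk; only the rows you index are read
        self.embeddings = np.load(emb_path, mmap_mode="r")
        self.gts = np.load(gt_path, mmap_mode="r")

    def __len__(self):
        return self.embeddings.shape[0]

    def __getitem__(self, idx):
        # np.array(...) copies just this sample into RAM before the torch conversion
        emb = torch.from_numpy(np.array(self.embeddings[idx])).float()
        gt = torch.from_numpy(np.array(self.gts[idx][None, :, :])).long()
        return emb, gt

With a map-style dataset like this, the DataLoader’s sampler takes care of shuffling and splitting work across workers, so you wouldn’t need the custom buffering at all.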

Another option to speed things up could be to cache your files as .pt with torch.save, to avoid converting the large arrays from NumPy to torch on every load.
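Something along these lines as a one-off conversion (just a sketch; the directory arguments are placeholders, and the keys match the ones in your npz files):

import os
import numpy as np
import torch

def cache_npz_as_pt(npz_dir, pt_dir):
    # One-off conversion: read each .npz once and save its arrays as torch tensors
    os.makedirs(pt_dir, exist_ok=True)
    for name in os.listdir(npz_dir):
        if not name.endswith(".npz"):
            continue
        with np.load(os.path.join(npz_dir, name)) as npz:
            tensors = {
                "img_embeddings": torch.from_numpy(npz["img_embeddings"]).float(),
                "gts": torch.from_numpy(npz["gts"]).long(),
            }
        torch.save(tensors, os.path.join(pt_dir, name.replace(".npz", ".pt")))

Training would then just call torch.load on the cached .pt files and skip the per-epoch NumPy-to-torch conversion.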

One final thing I could suggest is to use TorchData with in_memory_cache if you have the RAM.
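Roughly like this with the torchdata DataPipes API (a sketch only; the directory and cache size are placeholders, and the loader mirrors the load_file from your minimal example):

import os
import numpy as np
import torch
from torchdata.datapipes.iter import IterableWrapper

NPZ_DIR = "/path/to/npz_files/train"  # placeholder

def load_npz(name):
    # Same idea as load_file in the minimal example above
    with np.load(os.path.join(NPZ_DIR, name)) as npz:
        return (torch.from_numpy(npz["img_embeddings"]).float(),
                torch.from_numpy(npz["gts"]).long())

datapipe = IterableWrapper(sorted(os.listdir(NPZ_DIR)))
datapipe = datapipe.map(load_npz)
datapipe = datapipe.in_memory_cache(size=32_000)  # cache size is in megabytes

After the first pass over the data, later epochs would be served from RAM up to the cache size.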