Background
I’m trying to fine-tune a fork of the Segment-Anything model released by Meta on a custom medical imaging dataset. The problem is as follows:
- The training data is stored as 304 `npz` files on disk. Each file is ~1.4GB, so the total is roughly 430GB, too much to hold in RAM.
- The solution I devised is a custom `IterableDataset` that keeps a buffer. When the buffer drops below `batch_size`, it reads from disk to refill the buffer; meanwhile, it continues serving batches from the buffer.
- I’m running this with multiprocessing to address the bottleneck of load times.
- This results in deadlock-like behavior as I increase the number of workers. Even with `num_workers=4`, I get very long read times for the `npz` files.
Minimum Reproducible Example
Here is the code that reproduces this behavior:
```python
import os
import time
from collections import deque

import numpy as np
import torch
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm


class NpzIterableDataset(IterableDataset):
    def __init__(self, data_root, batch_size):
        super().__init__()
        self.data_root = data_root
        self.batch_size = batch_size
        self.npz_files = None  # set per worker by worker_init_fn
        self.embedding_buffer = deque()
        self.gt_buffer = deque()

    def _load_files(self, file_list):
        self.npz_files = file_list

    def _load(self):
        """Read one npz file from disk into the buffers. Returns False when exhausted."""
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else -999
        if not self.npz_files:
            return False
        start_time = time.time()
        filename = self.npz_files.pop()
        print(f"[Worker {worker_id}]: Npz files left: {len(self.npz_files)}")
        try:
            with np.load(os.path.join(self.data_root, filename)) as npz_data:
                embeddings_t = torch.from_numpy(npz_data["img_embeddings"]).float()
                gts_t = torch.from_numpy(npz_data["gts"][:, None, :, :]).long()
                self.embedding_buffer.extend(embeddings_t)
                self.gt_buffer.extend(gts_t)
        except Exception:
            return False  # treat a failed read as end-of-data for this worker
        end_time = time.time()
        print(f"[Worker {worker_id}]: Load time: {end_time - start_time:.2f}s")
        return True

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info else -999
        while True:
            # Stagger the workers slightly so they don't all hit the disk at once
            time.sleep(worker_id * 0.1)
            if len(self.embedding_buffer) < self.batch_size:
                if not self._load():
                    break  # No files left: stop iteration
            for _ in range(self.batch_size):
                yield self.embedding_buffer.pop(), self.gt_buffer.pop()


def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    # Configure the dataset to only process a split of the workload
    file_list = os.listdir(dataset.data_root)
    np.random.shuffle(file_list)
    sublists = [deque(sublist) for sublist in np.array_split(file_list, worker_info.num_workers)]
    dataset._load_files(sublists[worker_id])


TASK_PATH = "/home/khans24/medsam_finetuning/finetune_task"
NPZ_TR_PATH = "load/npz_files/train"
BATCH_SIZE = 64
NUM_WORKERS = 4
DEVICE = "cuda:1"

if __name__ == "__main__":
    for idx in range(2):
        train_dataset = NpzIterableDataset(os.path.join(TASK_PATH, NPZ_TR_PATH), batch_size=BATCH_SIZE)
        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
                                      worker_init_fn=worker_init_fn)
        train_loop_length = len(os.listdir(os.path.join(TASK_PATH, NPZ_TR_PATH)))  # number of npz files
        train_loop_length = (train_loop_length * 383) // BATCH_SIZE  # 383 samples per file -> number of batches
        for step, (image_embedding, gt) in enumerate(tqdm(train_dataloader, total=train_loop_length)):
            image_embedding = image_embedding.to(DEVICE)
            gt = gt.to(DEVICE)
            avg_i = torch.mean(image_embedding)
```
Here is the output I am logging:

```
(sam) ~/medsam_finetuning$ python minimum_repr.py
0%| | 0/1819 [00:00<?, ?it/s][Worker 0]: Npz files left: 75
[Worker 1]: Npz files left: 75
[Worker 2]: Npz files left: 75
[Worker 3]: Npz files left: 75
[Worker 3]: Load time: 8.71s
[Worker 0]: Load time: 9.05s
[Worker 2]: Load time: 9.02s
0%| | 1/1819 [00:19<9:49:08, 19.44s/it][Worker 1]: Load time: 85.61s
1%|▌ | 12/1819 [01:28<27:48, 1.08it/s][Worker 0]: Npz files left: 74
1%|▋ | 13/1819 [01:28<21:25, 1.40it/s][Worker 1]: Npz files left: 74
1%|▋ | 14/1819 [01:29<17:06, 1.76it/s][Worker 2]: Npz files left: 74
1%|▊ | 16/1819 [01:29<11:50, 2.54it/s][Worker 3]: Npz files left: 74
1%|▉ | 20/1819 [01:30<08:14, 3.64it/s][Worker 0]: Load time: 9.15s
1%|▉ | 21/1819 [01:38<1:16:11, 2.54s/it][Worker 2]: Load time: 9.09s
[Worker 3]: Load time: 8.81s
[Worker 1]: Load time: 108.77s
2%|█▋ | 36/1819 [03:21<12:56, 2.30it/s][Worker 0]: Npz files left: 73
2%|█▊ | 37/1819 [03:21<11:07, 2.67it/s][Worker 1]: Npz files left: 73
2%|█▊ | 38/1819 [03:21<09:58, 2.98it/s][Worker 2]: Npz files left: 73
2%|█▉ | 40/1819 [03:22<08:07, 3.65it/s][Worker 3]: Npz files left: 73
2%|██▏ | 44/1819 [03:23<06:56, 4.26it/s][Worker 0]: Load time: 8.94s
2%|██▏ | 45/1819 [03:30<1:13:21, 2.48s/it][Worker 2]: Load time: 9.34s
[Worker 1]: Load time: 247.14s
3%|██▏ | 47/1819 [07:29<25:17:00, 51.37s/it][Worker 3]: Load time: 248.15s
```

Furthermore, if I kill the program when it hangs for longer than 120 seconds, this is where the main process is stopped (blocked in `_data_queue.get`, waiting for a worker to return data):

```
Traceback (most recent call last):
File "/home/khans24/medsam_finetuning/train_debug.py", line 175, in <module>
for step, (image_embedding, gt) in enumerate(train_dataloader):
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
data = self._next_data()
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
idx, data = self._get_data()
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1295, in _get_data
success, data = self._try_get_data()
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/queues.py", line 113, in get
if not self._poll(timeout):
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/connection.py", line 258, in poll
return self._poll(timeout)
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/connection.py", line 425, in _poll
r = wait([self], timeout)
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/multiprocessing/connection.py", line 932, in wait
ready = selector.select(timeout)
File "/home/khans24/miniconda3/envs/sam/lib/python3.10/selectors.py", line 416, in select
fd_event_list = self._selector.poll(timeout)
```

Other Checks
To make sure this was not an IOPS issue, here is the output of `iostat -dx 1` while the program is running inside a `tmux` session:

```
Device: rrqm/s wrqm/s r/s w/s rkB/s wkB/s avgrq-sz avgqu-sz await r_await w_await svctm %util
sda 0.00 0.06 3.96 0.71 68.34 16.72 36.40 0.00 0.09 0.06 0.24 0.03 0.02
dm-0 0.00 0.00 3.97 0.78 68.34 16.72 35.86 0.00 0.09 0.06 0.25 0.03 0.02
dm-1 0.00 0.00 0.00 0.00 0.00 0.00 48.19 0.00 0.15 0.15 0.00 0.08 0.00
Device: rrqm/s wrqm/s r/s w/s rkB/s wkB/s avgrq-sz avgqu-sz await r_await w_await svctm %util
sda 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
dm-0 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
dm-1 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
Device: rrqm/s wrqm/s r/s w/s rkB/s wkB/s avgrq-sz avgqu-sz await r_await w_await svctm %util
sda 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
dm-0 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
dm-1 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00 0.00
```

So I don’t think this is an issue of too many reads/writes from disk.
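For anyone who wants to separate raw `np.load` speed from DataLoader worker behavior, here is a minimal single-process timing sketch (illustrative only, not part of the repro); if the per-file times here match the ~9s loads above, then the 100s+ loads would be specific to the multi-worker setup:

```python
# Sketch: time np.load on a few npz files in a single process, outside any
# DataLoader, to separate raw disk/decompression speed from worker behavior.
import os
import time

import numpy as np

DATA_ROOT = "/home/khans24/medsam_finetuning/finetune_task/load/npz_files/train"

for filename in sorted(os.listdir(DATA_ROOT))[:3]:
    start = time.time()
    with np.load(os.path.join(DATA_ROOT, filename)) as npz_data:
        embeddings = npz_data["img_embeddings"]  # indexing forces the actual read
    print(f"{filename}: {time.time() - start:.2f}s, {embeddings.nbytes / 1e9:.2f} GB")
```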
Problem/Questions
- Based on this, I suspect each worker is blocking while waiting for the `npz` file object to become readable. But given my implementation of `worker_init_fn`, the list of npz files sent to each copy of `NpzIterableDataset` inside each worker should be a disjoint subset. So what shared resource is each worker blocking on?
- I’m open to any suggestions for a different approach to my problem of streaming data in from disk.
- I’m also open to any multiprocess debugging solutions (see the sketch after this list for the kind of thing I mean). I’ve tried pudb, but it tends to be unresponsive and crashes whenever I use the command-line feature inside it.
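For context on the last point, here is a minimal sketch, using only the standard library, of the kind of inspection I mean: it wraps the repro’s `worker_init_fn` so each worker dumps its Python stack when it receives `SIGUSR1` (the name `debug_worker_init_fn` is just for illustration; an external tool like `py-spy dump` would serve the same purpose):

```python
# Sketch: make each DataLoader worker dump its Python stack on SIGUSR1, so a
# hung worker can be inspected from another shell with `kill -USR1 <worker_pid>`.
import faulthandler
import signal

def debug_worker_init_fn(worker_id):
    # Dump the stacks of all threads in this worker to stderr on SIGUSR1
    faulthandler.register(signal.SIGUSR1, all_threads=True)
    worker_init_fn(worker_id)  # then run the existing per-worker file split
```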
Misc
- Another relevant thread on this forum where the questioner seems to be having the same issue with parallelizing an `IterableDataset`. I don’t think the solution works for me, though, since my dataset is very large.
Lastly, thank you for reading this post and for your thoughts and suggestions!