torch.utils.data.random_split and DataLoader get num_samples=0 on GPU

I am doing some model training on a remote cluster that has a GPU (my local machine does not, which makes debugging slow and tricky). I have tested that my code creates DataLoaders as expected locally, but when I send the job off, I get the following warnings and traceback:

/home/jdivers/nsclc/jobs/big_comet_w_atlases/
/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 0 is 0. This might result in an empty dataset.
  warnings.warn(f"Length of split at index {i} is 0. "
/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 1 is 0. This might result in an empty dataset.
  warnings.warn(f"Length of split at index {i} is 0. "
/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 2 is 0. This might result in an empty dataset.
  warnings.warn(f"Length of split at index {i} is 0. "
Traceback (most recent call last):
  File "/scr1/398185/big_comet_w_atlases.py", line 150, in <module>
    main()
  File "/scr1/398185/big_comet_w_atlases.py", line 54, in main
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 350, in __init__
    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/sampler.py", line 143, in __init__
    raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
ValueError: num_samples should be a positive integer value, but got num_samples=0
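
The ValueError itself seems to just mean that whatever dataset reaches the DataLoader has length zero. For reference, a toy example (nothing to do with my actual dataset class) reproduces the same warnings and error:

    import torch
    from torch.utils.data import TensorDataset, DataLoader, random_split

    # A dataset with zero samples, standing in for whatever ends up empty on the cluster
    empty_ds = TensorDataset(torch.empty(0, 3))

    # Fractional splits of an empty dataset emit the "Length of split at index i is 0" warnings
    train_set, val_set, test_set = random_split(empty_ds, [0.8, 0.1, 0.1])

    # shuffle=True builds a RandomSampler, which raises as soon as it sees num_samples=0
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True)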

I am using a custom dataset class for which I wrote a method to send the entire dataset (as it is loaded and cached) to a device, in this case the GPU. That method is here:

    def to(self, device):
        # Move caches to device
        if self.index_cache is not None:
            self.index_cache = self.index_cache.to(device)
            for i, idx in enumerate(self.index_cache):
                self.index_cache[i] = idx.to(device)
            for i, x_cache in enumerate(self.shared_x):
                self.shared_x[i] = x_cache.to(device)
            self.shared_y = self.shared_y.to(device)
            for i, y in enumerate(self.shared_y):
                self.shared_y[i] = y.to(device)

        # Move any self-held tensors to device for ops compatibility
        self.scalars = self.scalars.to(device) if self.scalars is not None else None

        # Update device for future items
        self.device = device

As you can tell, I am also using an optional cache. I’ve checked these caches on the CPU as well, to make sure they are populating as expected. Though even if they weren’t, my dataset would still look like a non-zero-sized set of empty tensors. Here’s the cache creation method, in case it’s relevant:

    def _open_cache(self, x, y):
        # Setup shared memory arrays (i.e. caches that are compatible with multiple workers)
        # negative initialization ensures no overlap with actual cached indices
        cache_len = len(self.all_atlases) if self._use_atlas else len(self.all_fovs)
        index_cache_base = mp.Array(ctypes.c_int, cache_len * [-1])
        shared_x_base = cache_len * [mp.Array(ctypes.c_float, 0)]

        # Label-size determines cache size, so if no label is set, we will fill cache with -999999 at __getitem__
        match self.label:
            case 'Response' | 'Metastases' | None:
                shared_y_base = mp.Array(ctypes.c_float, cache_len * [-1])
                y_shape = ()
            case 'Mask':
                shared_y_base = mp.Array(ctypes.c_float, int(cache_len * np.prod(y.shape)))
                y_shape = tuple(y.shape)
            case _:
                raise Exception('An unrecognized label is in use that is blocking the cache from initializing. '
                                'Update label attribute of dataset and try again.')

        # Convert all arrays to desired data structure
        self.index_cache = convert_mp_to_torch(index_cache_base, (cache_len,), device=self.device)
        self.shared_x = [convert_mp_to_torch(x_base, 0, device=self.device) for x_base in shared_x_base]
        self.shared_y = convert_mp_to_torch(shared_y_base, (cache_len,) + y_shape, device=self.device)
        print('Cache opened.')

Note, the function convert_mp_to_torch is just a simple conversion from an mp.Array, through an np.array, to a torch.tensor that gets reshaped appropriately, so the resultant caches are all torch.tensors.
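
In case it helps, it does roughly this (a simplified sketch, not the exact code):

    import numpy as np
    import torch

    def convert_mp_to_torch(mp_array, shape, device='cpu'):
        # Sketch only -- view the shared ctypes buffer as a numpy array
        # (dtype is inferred from the underlying ctypes type) ...
        np_view = np.ctypeslib.as_array(mp_array.get_obj())
        # ... then wrap it as a tensor, reshape it, and move it to the target device
        return torch.from_numpy(np_view).reshape(shape).to(device)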

I’m not expecting anyone to solve this for me, given the complexity of my dataset, but since I am only getting this error on the remote cluster, any ideas on where to begin looking or how to focus my debugging would be much appreciated!

It seems the remote machine might have issues creating the dataset in the first place, as seen in these warnings:

/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 0 is 0. This might result in an empty dataset.
  warnings.warn(f"Length of split at index {i} is 0. "
/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 1 is 0. This might result in an empty dataset.
  warnings.warn(f"Length of split at index {i} is 0. "
/home/jdivers/.conda/envs/dl_env/lib/python3.12/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 2 is 0. This might result in an empty dataset.
  warnings.warn(f"Length of split at index {i} is 0. "

All splits are empty, so make sure samples are found (e.g. check file paths for their validity).
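
For example, you could add something like this right before the split (all names here are placeholders for whatever your script actually uses):

    import os

    data_root = '/path/where/the/job/stages/the/data'   # placeholder for your actual path
    print('exists:', os.path.isdir(data_root))
    print('n entries:', len(os.listdir(data_root)) if os.path.isdir(data_root) else 0)

    dataset = YourDataset(data_root)                     # placeholder for your dataset class
    print('len(dataset) =', len(dataset))                # must be > 0 before random_split
    assert len(dataset) > 0, 'No samples found - check what the job script actually copied'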

:man_facepalming: I thought that wasn’t an option, but figured I’d better double-check, and sure enough, the data is not being properly copied… I never considered it could be as simple as a job script failure. Thanks for taking a second to help!