Best practices for infinite data using IterableDataset

I am training a model using “teacher” distributions which are fully known. That is, I am sampling fresh points at each batch/epoch/repetition etc. from a specified distribution. Each training example comprises an input (shape 1 x D, float), context (shape 1, int), and target (shape 1, float). D is typically between 8 and 64. Batch sizes have been varied from 4 to 4096, with no success.

My issue is that the GPU utilization is typically very low (below 10%), despite trying anywhere from 1 to 8 CPU workers for loading the data, with pin_memory=True and prefetch_factor=2. Since the sampling process (generating new points) is rather simple, would it be better for me to simply generate the points directly on the GPU using torch operations?

Any and all help is truly appreciated!

For reference, here is the DataStream class which generates new points using the sample() method:

class ClassificationDataStream(ContextDataStreamBase):
    """Stream of freshly sampled ``(x, contexts, targets)`` batches.

    Every call to :meth:`sample` draws new inputs and new context indices,
    then computes each target as the sum over features of
    ``[x, 1] * boundaries[context]``, i.e. an affine function of ``x`` whose
    coefficients are selected by the context index.

    ``self.device``, ``self.dim`` and ``self.n_contexts`` are read by
    :meth:`sample`; they are presumably set by ``ContextDataStreamBase``
    from the constructor arguments -- TODO confirm against the base class.
    """

    def __init__(
        self,
        device: torch.device,
        dim: int,
        n_contexts: int,
        boundaries: torch.Tensor,
        rand_seed: int | None = None,
    ) -> None:
        """Store the per-context coefficients used to compute targets.

        Args:
            device: Device on which batches are generated.
            dim: Number of input features per example.
            n_contexts: Number of distinct context indices.
            boundaries: Tensor indexed by context; assumed to have shape
                ``(n_contexts, dim + 1)`` (weights followed by a bias) so it
                broadcasts against ``[x, 1]`` -- TODO confirm with caller.
            rand_seed: Forwarded unchanged to the base class.
        """
        super().__init__(device, dim, n_contexts, rand_seed)

        # Move the coefficients to the sampling device once, up front.
        # sample() indexes them with an on-device index tensor, which fails
        # when the two tensors live on different devices. `.to` returns the
        # same tensor when it is already on `device`.
        self.boundaries = boundaries if device is None else boundaries.to(device)

    def sample(self, batch_size, sampling_method="gaussian", *args, **kwargs):
        """Draw one batch of inputs, contexts and targets.

        Args:
            batch_size: Number of examples to generate.
            sampling_method: ``"gaussian"`` draws standard-normal values
                divided by ``dim``; ``"uniform"`` draws values in ``[-1, 1)``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.

        Returns:
            Tuple ``(x, contexts, targets)`` where ``x`` has shape
            ``(batch_size, dim)``, ``contexts`` has shape ``(batch_size,)``
            with integer values in ``[0, n_contexts)``, and ``targets`` is
            the result of summing over dimension 1.

        Raises:
            ValueError: If ``sampling_method`` is not a supported name.
        """
        # Validate with a real exception: `assert` is stripped under `-O`,
        # and it made the ValueError below unreachable in normal runs.
        if sampling_method not in ("gaussian", "uniform"):
            msg = "Invalid sampling method provided."
            raise ValueError(msg)

        with torch.no_grad():
            if sampling_method == "gaussian":
                x = (
                    torch.randn(batch_size, self.dim, device=self.device).float()
                    / self.dim
                )
            else:  # "uniform"
                x = (
                    2 * torch.rand(batch_size, self.dim, device=self.device) - 1
                ).float()

            contexts = torch.randint(
                0, self.n_contexts, (batch_size,), device=self.device
            )

            # The bias column must live on the same device as `x`; without
            # `device=` it is allocated on the default device and hstack
            # raises a device-mismatch error for any non-default device.
            bias_column = torch.ones((batch_size, 1), device=self.device)
            targets = (
                torch.hstack([x, bias_column]) * self.boundaries[contexts]
            ).sum(dim=1)

        return x, contexts, targets

Here is the IterableDataset subclass which I pass as the dataset argument of a DataLoader:

class DataStreamIterable(IterableDataset):
    def __init__(
        self,
        device: torch.device,
        task: str,
        dim: int,
        n_contexts: int,
        batch_size: int = 32,
        rand_seed: int | None = None,
        boundaries: torch.Tensor = None,
        local_boundaries: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()

        self.device = device
        assert task in ["classification", "regression"], "Invalid task provided."
        self.task = task
        self.dim = dim
        self.n_contexts = n_contexts
        self.batch_size = batch_size
        self.rand_seed = rand_seed
        self.task_kwargs = kwargs

        worker_info = get_worker_info()

        if task == "classification":
            if boundaries is None:
                if local_boundaries:
                    self.generate_boundaries_local(
                        **kwargs
                    )  # TODO: Implement this method with von Mises distribution
                else:
                    self.generate_boundaries(**kwargs)
            else:
                self.boundaries = boundaries

            if worker_info is None:
                self.datastream = ClassificationDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    boundaries=self.boundaries,
                    rand_seed=self.rand_seed,
                )
            else:
                worker_id = worker_info.id
                self.datastream = ClassificationDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    boundaries=self.boundaries,
                    rand_seed=self.rand_seed + worker_id,
                )

        elif task == "regression":
            self.functions = nn.ModuleList(
                [ContextMLP(input_size=dim, **kwargs) for _ in range(self.n_contexts)]
            )
            if worker_info is None:
                self.datastream = RegressionDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    functions=self.functions,
                    rand_seed=self.rand_seed,
                )
            else:
                worker_id = worker_info.id
                self.datastream = RegressionDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    functions=self.functions,
                    rand_seed=self.rand_seed + worker_id,
                )

    def __iter__(self):
        while True:
            yield self.datastream.sample(self.batch_size, **self.task_kwargs)