I am training a model on "teacher" distributions that are fully known. That is, at every batch/epoch/repetition I sample fresh points from a specified distribution. Each training example comprises an input (shape 1 x D, float), a context (shape 1, int), and a target (shape 1, float). D is typically between 8 and 64. I have varied the batch size from 4 to 4096 with no success.
My issue is that GPU utilization is typically very low (under 10%), despite trying anywhere from 1 to 8 CPU workers for data loading, with pin_memory=True and prefetch_factor=2. Since the sampling process (generating new points) is quite simple, would it be better to generate the points directly on the GPU using torch operations?
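For concreteness, the alternative I have in mind looks roughly like this: drop the DataLoader and call the stream's sample() method directly inside the training loop. (The tiny linear model, loss, and boundaries below are placeholders just to make the sketch self-contained, not my real setup.)

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dim, n_contexts = 32, 4

# Placeholder boundaries / model / optimizer, only to illustrate the shape of the loop.
boundaries = torch.randn(n_contexts, dim + 1, device=device)
stream = ClassificationDataStream(
    device=device, dim=dim, n_contexts=n_contexts, boundaries=boundaries, rand_seed=0
)
model = nn.Linear(dim, 1).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for step in range(1000):
    # Samples are created directly on the GPU, so there is no host-to-device copy.
    x, contexts, targets = stream.sample(256)
    loss = nn.functional.mse_loss(model(x).squeeze(-1), targets)
    opt.zero_grad()
    loss.backward()
    opt.step()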
Any and all help is truly appreciated!
For reference, here is the DataStream class, which generates new points via its sample() method:
import torch


class ClassificationDataStream(ContextDataStreamBase):
    def __init__(
        self,
        device: torch.device,
        dim: int,
        n_contexts: int,
        boundaries: torch.Tensor,
        rand_seed: int | None = None,
    ) -> None:
        super().__init__(device, dim, n_contexts, rand_seed)
        # Keep the boundaries on the sampling device so indexing them with a
        # GPU tensor of contexts does not raise a device-mismatch error.
        self.boundaries = boundaries.to(device)

    def sample(self, batch_size, sampling_method="gaussian", *args, **kwargs):
        assert sampling_method in [
            "gaussian",
            "uniform",
        ], "Invalid sampling method provided."
        with torch.no_grad():
            if sampling_method == "gaussian":
                x = (
                    torch.randn(batch_size, self.dim, device=self.device).float()
                    / self.dim
                )
            elif sampling_method == "uniform":
                x = (
                    2 * torch.rand(batch_size, self.dim, device=self.device) - 1
                ).float()
            else:
                msg = "Invalid sampling method provided."
                raise ValueError(msg)
            contexts = torch.randint(
                0, self.n_contexts, (batch_size,), device=self.device
            )
            # Append a bias column and project onto the per-context boundary;
            # the ones must be created on the same device as x.
            targets = (
                torch.hstack([x, torch.ones((batch_size, 1), device=self.device)])
                * self.boundaries[contexts]
            ).sum(dim=1)
        return x, contexts, targets
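A single draw from this stream has the shapes/dtypes described above; for example, with a hypothetical dim=8 and n_contexts=4:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
boundaries = torch.randn(4, 9, device=device)  # placeholder, shape (n_contexts, dim + 1)
stream = ClassificationDataStream(device=device, dim=8, n_contexts=4, boundaries=boundaries)
x, contexts, targets = stream.sample(16)
# x: (16, 8) float, contexts: (16,) int64, targets: (16,) float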
Here is the IterableDataset subclass that I pass as the dataset argument of a DataLoader:
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, get_worker_info


class DataStreamIterable(IterableDataset):
    def __init__(
        self,
        device: torch.device,
        task: str,
        dim: int,
        n_contexts: int,
        batch_size: int = 32,
        rand_seed: int | None = None,
        boundaries: torch.Tensor | None = None,
        local_boundaries: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.device = device
        assert task in ["classification", "regression"], "Invalid task provided."
        self.task = task
        self.dim = dim
        self.n_contexts = n_contexts
        self.batch_size = batch_size
        self.rand_seed = rand_seed
        self.task_kwargs = kwargs
        # Check whether we are inside a DataLoader worker process
        # (get_worker_info() returns None in the main process).
        worker_info = get_worker_info()
        if task == "classification":
            if boundaries is None:
                if local_boundaries:
                    self.generate_boundaries_local(
                        **kwargs
                    )  # TODO: Implement this method with von Mises distribution
                else:
                    self.generate_boundaries(**kwargs)
            else:
                self.boundaries = boundaries
            if worker_info is None:
                self.datastream = ClassificationDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    boundaries=self.boundaries,
                    rand_seed=self.rand_seed,
                )
            else:
                # Offset the seed per worker so workers do not draw identical streams
                # (fall back to 0 if no seed was given).
                worker_id = worker_info.id
                self.datastream = ClassificationDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    boundaries=self.boundaries,
                    rand_seed=(self.rand_seed or 0) + worker_id,
                )
        elif task == "regression":
            self.functions = nn.ModuleList(
                [ContextMLP(input_size=dim, **kwargs) for _ in range(self.n_contexts)]
            )
            if worker_info is None:
                self.datastream = RegressionDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    functions=self.functions,
                    rand_seed=self.rand_seed,
                )
            else:
                worker_id = worker_info.id
                self.datastream = RegressionDataStream(
                    device=self.device,
                    dim=self.dim,
                    n_contexts=self.n_contexts,
                    functions=self.functions,
                    rand_seed=(self.rand_seed or 0) + worker_id,
                )

    def __iter__(self):
        # sample() already produces a full batch, so the stream is infinite.
        while True:
            yield self.datastream.sample(self.batch_size, **self.task_kwargs)
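For completeness, this is roughly how I wire it into the DataLoader at the moment (the exact values vary between runs; batch_size=None because the dataset already emits full batches, and the stream samples on the CPU here before batches are moved to the GPU in the training loop):

import torch
from torch.utils.data import DataLoader

boundaries = torch.randn(4, 33)  # placeholder, shape (n_contexts, dim + 1)
dataset = DataStreamIterable(
    device=torch.device("cpu"),
    task="classification",
    dim=32,
    n_contexts=4,
    batch_size=256,
    rand_seed=0,
    boundaries=boundaries,
)
loader = DataLoader(
    dataset,
    batch_size=None,      # the dataset already yields full batches
    num_workers=4,        # I have tried 1 to 8
    pin_memory=True,
    prefetch_factor=2,
)

for x, contexts, targets in loader:
    x, contexts, targets = x.cuda(), contexts.cuda(), targets.cuda()
    ...  # forward/backward pass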