Hi all!
I have a large time series database that doesn’t fit in memory. It’s composed of time series of varying length that are stored in a given folder in parquet format. What I want to do is use a sliding window with a fixed size to create training samples for each time series. Given that each time series has an arbitrary length, the number of samples created by the sliding window approach varies. I also do some other pre-processing such as normalising each window and further processing to put it in the format the model expects. In order to train using the desired batch_size
, I do some extra checks and munging to yield batches of the right size. Finally, I’m training the model using DDP across 4 GPUs.
I have the following working code:
# Some globals defining window size and some other hyperparameters
window_size = 1024
[...]
class LocalIterableDataset(IterableDataset):
def __init__(self, path):
self.path = Path(path)
self.files = list(self.path.glob("*.parquet"))
# These below are used to split the dataset across GPUs without data repetition as described in the docs
self.start = 0
self.end = len(self.files)
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None: # multiple processes
per_worker = int(
math.ceil((self.end - self.start) / float(worker_info.num_workers))
)
iter_start = self.start + worker_info.id * per_worker
iter_end = min(iter_start + per_worker, self.end)
iter_files = self.files[iter_start:iter_end]
else: # single process
iter_files = self.files
for file in iter_files:
window_inputs, window_outputs = self.process_file(file)
for i in range(0, len(window_inputs), batch_size):
# Only return full batches. This is avoid a torch.stack dimensionality error.
# TODO: revisit this to make sure we use all the data.
if i + batch_size > len(window_inputs):
break
yield {
"input_data": torch.from_numpy(
window_inputs[i : i + batch_size]
).to(torch.float32),
"targets": torch.from_numpy(window_outputs[i : i + batch_size]).to(
torch.float32
),
}
def process_file(self, file_path):
"""Read parquet file and process."""
df = pd.read_parquet(file_path)
inputs, outputs = self.create_rolling_windows(df)
return (inputs, outputs)
def create_rolling_windows(self, df):
"""Roll window, normalise, and split into inputs and outputs."""
if len(df) < self.window_size:
return (np.array([]), np.array([]))
windows = []
for col in df.columns:
ts = df[col].dropna().to_numpy()
for idx in range(len(ts) - self.window_size + 1):
segment = ts[idx : idx + self.window_size]
normalised_segment = (
(segment - np.mean(segment)) / np.std(segment)
if np.std(segment) != 0
else np.zeros_like(segment)
)
windows.append(normalised_segment)
# process each window to the format expected by the model
processed_windows = [
self.split_ts_into_inputs_and_outputs(
window, input_patch_length, output_length, max_input_patches
)
for window in windows
]
# separate inputs and outputs
inputs = np.array([window[0] for window in processed_windows])
outputs = np.array([window[1] for window in processed_windows])
return (inputs, outputs)
def split_ts_into_inputs_and_outputs(
self, ts, input_patch_length, output_length, max_input_patches
):
"""This transforms a data window to the format expected by the model."""
input_ts_segment = ts[: input_patch_length * max_input_patches]
output_patches = []
for patch_idx in range(max_input_patches):
output_patch = ts[
(patch_idx + 1)
* input_patch_length : (patch_idx + 1)
* input_patch_length
+ output_length
]
output_patches.append(output_patch)
return (input_ts_segment, np.array(output_patches))
train_loader = DataLoader(
train_dataset,
batch_size=None,
shuffle=False,
)
[...]
trainer = L.Trainer(
accelerator="gpu",
num_nodes=1,
devices=4,
strategy="ddp",
)
trainer.fit(
model,
train_dataloaders=train_loader,
)
This works and I trains fine on the 4 GPUs but, in its current form, there seems to be a memory duplication issue across GPUs, and across dataloaders if I use num_workers
. I’ve read about this (torch.utils.data — PyTorch 2.2 documentation) but cannot figure out how to apply the recommended practices to my use-case.
Any pointers welcome!