Speeding up data loading from Spark

I’ve been using Petastorm to train PyTorch models on Spark DataFrames (loosely following this guide), and I’m looking for ways to speed up data loading.

Here’s a basic overview of my current flow. df_train is a Spark DataFrame with three columns: x1 (float), x2 (binary, 0 or 1), and y (float). I’m using PySpark.

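import torch
from petastorm.spark import SparkDatasetConverter, make_spark_converter

# spark, df_train, bs (batch size) and device are defined earlier in the notebook
x_feat = ['x1', 'x2']
y_name = 'y'
# Directory where petastorm materializes the DataFrame as Parquet files
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, "file:///dbfs/tmp/petastorm/cache")
converter_train = make_spark_converter(df_train)

with converter_train.make_torch_dataloader(batch_size=bs) as train_dataloader:
    train_dataloader_iter = iter(train_dataloader)
    steps_per_epoch = len(converter_train) // bs
    for step in range(steps_per_epoch):
        # Each batch is a dict mapping each column name to a 1-D tensor
        pd_batch = next(train_dataloader_iter)
        # Stack the feature columns into a (batch, n_features) matrix
        pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)
        inputs = pd_batch['features'].to(device)
        labels = pd_batch[y_name].to(device)
        ...  # modeling and stuff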
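The stack-then-transpose step in the loop above can also be written as a single `torch.stack(..., dim=1)` call. The sketch below uses a made-up batch dict as a stand-in for one petastorm batch (the values are hypothetical) and checks that both forms give the same matrix; the only structural difference is that `torch.transpose` returns a non-contiguous view, while stacking along `dim=1` builds the (batch, n_features) layout directly. With only two feature columns this work is small, so it is probably not where most of the epoch time goes; profiling the loop would confirm that.

```python
import torch

# Hypothetical stand-in for one petastorm batch: column name -> 1-D tensor
batch = {
    'x1': torch.tensor([0.5, 1.5, 2.5, 3.5]),
    'x2': torch.tensor([0.0, 1.0, 1.0, 0.0]),
    'y': torch.tensor([1.0, 2.0, 3.0, 4.0]),
}
x_feat = ['x1', 'x2']

# Original form: stack to (n_features, batch), then transpose to (batch, n_features).
# torch.transpose returns a view with swapped strides, so it is non-contiguous.
features_t = torch.transpose(torch.stack([batch[x] for x in x_feat]), 0, 1)

# Equivalent form: stack along dim=1 to get (batch, n_features) directly.
features = torch.stack([batch[x] for x in x_feat], dim=1)

assert torch.equal(features, features_t)
print(features.shape, features.is_contiguous(), features_t.is_contiguous())
# torch.Size([4, 2]) True False
```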

pd_batch is a dictionary with one entry per column of the original df_train. My concern is that the torch stacking and transposing might not be optimal. I also tried first combining x1 and x2 into a single array column in the Spark DataFrame, and was surprised to find that each epoch was more than twice as slow as with the approach above.

from pyspark.sql.functions import array

# Pack x1 and x2 into a single array column named "features"
df_train = df_train.withColumn("features", array("x1", "x2")).select("features", "y")
# The rest is the same as above, except the torch.transpose(torch.stack(...), 0, 1) line is removed
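One plausible, unverified explanation for the slowdown: Parquet stores an array column as a nested list per row, and after decoding on the Python side such a column typically surfaces as a sequence of small per-row arrays that must be combined into a 2-D block, whereas scalar columns already arrive as flat contiguous arrays. The sketch below contrasts the two layouts with plain NumPy and made-up data; it illustrates the extra per-row work only and is not petastorm’s actual decoding code.

```python
import numpy as np

n_rows = 6
rng = np.random.default_rng(0)

# Separate scalar columns: each is one contiguous float array,
# so building the feature matrix is a single vectorized stack.
x1 = rng.random(n_rows)
x2 = rng.integers(0, 2, n_rows).astype(np.float64)
features_from_columns = np.stack([x1, x2], axis=1)

# Array-column layout: one small array per row held in an object array,
# similar to how a decoded nested-list column can look.
array_column = np.empty(n_rows, dtype=object)
for i in range(n_rows):
    array_column[i] = np.array([x1[i], x2[i]])

# Assembling per-row arrays into a 2-D block touches every row object in Python.
features_from_arrays = np.vstack(array_column.tolist())

print(features_from_columns.shape, features_from_arrays.shape)
```

Both paths produce the same (n_rows, 2) matrix; the difference is that the second one scales with the number of rows in Python-level work, which would be consistent with the array-column version being slower.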

Are there any improvements I can make here?
