Errors doing inference on PySpark Dataframe with Huggingface models

I am able to run the following code on a PySpark Dataframe containing a column of text:

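from transformers import AutoTokenizer, AutoModelWithLMHead
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelWithLMHead.from_pretrained("gpt2")
str2tokens = udf(lambda s: tokenizer.tokenize(s), ArrayType(StringType()))
tokens2ids = udf(lambda s: tokenizer.convert_tokens_to_ids(s), ArrayType(IntegerType()))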

So I am able to load PyTorch and Huggingface libraries, at least. The problem comes when I try to do inference and write out the logits using the following udf:

from torch import tensor
vectorize_ids = udf(lambda s: model(tensor(s)).logits.mean(axis=0).detach().numpy().tolist(), ArrayType(DoubleType()))

I am getting Java heap space errors even when running on tiny files:

21/01/15 13:17:00 WARN Utils: Your hostname, Evans-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.7.76 instead (on interface en0)
21/01/15 13:17:00 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
21/01/15 13:17:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/01/15 13:17:32 WARN Utils: Suppressing exception in finally: Java heap space
java.lang.OutOfMemoryError: Java heap space
	at java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:57)
	at java.nio.ByteBuffer.allocate(ByteBuffer.java:335)
	at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$3.apply(TorrentBroadcast.scala:286)
	at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$3.apply(TorrentBroadcast.scala:286)
	at org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:87)
	at org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:75)
	at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:220)
	at net.jpountz.lz4.LZ4BlockOutputStream.finish(LZ4BlockOutputStream.java:252)
	at net.jpountz.lz4.LZ4BlockOutputStream.close(LZ4BlockOutputStream.java:190)
	at java.io.ObjectOutputStream$BlockDataOutputStream.close(ObjectOutputStream.java:1828)
	at java.io.ObjectOutputStream.close(ObjectOutputStream.java:742)
	at org.apache.spark.serializer.JavaSerializationStream.close(JavaSerializer.scala:57)
	at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$blockifyObject$1.apply$mcV$sp(TorrentBroadcast.scala:293)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1369)
	at org.apache.spark.broadcast.TorrentBroadcast$.blockifyObject(TorrentBroadcast.scala:292)
	at org.apache.spark.broadcast.TorrentBroadcast.writeBlocks(TorrentBroadcast.scala:127)
	at org.apache.spark.broadcast.TorrentBroadcast.<init>(TorrentBroadcast.scala:88)
	at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34)
	at org.apache.spark.broadcast.BroadcastManager.newBroadcast(BroadcastManager.scala:62)
	at org.apache.spark.SparkContext.broadcast(SparkContext.scala:1489)
	at org.apache.spark.api.java.JavaSparkContext.broadcast(JavaSparkContext.scala:650)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)

Is there something fundamentally wrong with this approach? I have also tried using bigger instances on Amazon EMR, with no luck.
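
One possible reading of the trace, offered as an assumption rather than a confirmed diagnosis: the failing call is SparkContext.broadcast reached through py4j, which is what one would expect if the lambda's closure (which references model) is pickled on the driver and shipped to the workers as a broadcast. The sketch below uses only the standard library and a hypothetical stand-in for the model parameters (fake_weights). It compares the size of a payload that carries the parameters with one that carries only a model name that workers could use to load the model themselves.

```python
import pickle

# Hypothetical stand-in for model parameters; a real GPT-2 checkpoint
# is far larger than this list of floats.
fake_weights = [float(i) for i in range(200_000)]

# A lightweight reference that a worker could resolve on its own,
# e.g. by calling from_pretrained(model_name) inside the udf.
model_name = "gpt2"

# Payload that ships the parameters along with the task.
payload_with_weights = len(pickle.dumps({"weights": fake_weights}))

# Payload that ships only the name.
payload_with_reference = len(pickle.dumps({"model_name": model_name}))

print(f"with weights:   {payload_with_weights} bytes")
print(f"with reference: {payload_with_reference} bytes")
```

If this is indeed the cause, constructing the model inside the function that runs on the workers, instead of referencing a driver-side model object from the lambda, would keep the serialized udf small.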

A one-row udf is pretty slow, since the model state_dict() needs to be loaded for each row. I’m trying to use pandas_udf to speed this up, since all the operations can be vectorized efficiently in pandas/PyTorch.
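
The per-row loading cost described above can be illustrated without Spark. This is a minimal sketch, assuming a hypothetical load_model() that stands in for the expensive load_state_dict call. Caching the loaded model in a module-level variable means each Python process pays the load cost once instead of once per row. The names load_model, get_cached_model, one_row_predict_cached, and load_calls are illustrative and not part of PyTorch or PySpark.

```python
# Counts how often the (pretend) expensive load happens.
load_calls = 0

def load_model():
    """Hypothetical stand-in for an expensive model load."""
    global load_calls
    load_calls += 1
    weights = [0.1, 0.2, 0.3, 0.4, 0.5]
    bias = 0.05
    # The "model" is a plain linear function: dot(weights, x) + bias.
    return lambda x: sum(w * xi for w, xi in zip(weights, x)) + bias

# Module-level cache: filled on first use, reused afterwards.
_model_cache = None

def get_cached_model():
    global _model_cache
    if _model_cache is None:
        _model_cache = load_model()
    return _model_cache

def one_row_predict_cached(x):
    # Each call reuses the cached model instead of reloading it.
    return get_cached_model()(x)

rows = [
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0, 0.0],
    [1.0, 1.0, 1.0, 1.0, 1.0],
]
cached_predictions = [one_row_predict_cached(r) for r in rows]
print(cached_predictions, "loads:", load_calls)
```

The same pattern applies to a batch function: the cache removes the repeated load, while batching additionally amortizes the per-call overhead of the forward pass.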

I’ve looked at this Databricks post for inspiration, but it doesn’t correspond exactly to my use case, since I want to run prediction on an existing PySpark dataframe.

I can get it to work using a one-row udf in this simple example:

import torch
import torch.nn as nn
from pyspark.sql.functions import col, pandas_udf, PandasUDFType, udf
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, FloatType, DoubleType
import pandas as pd
import numpy as np

spark = SparkSession.builder.master('local[*]') \
    .appName("model_training") \
    .getOrCreate()

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.w = nn.Linear(5, 1)

    def forward(self, x):
        return self.w(x)

net = Net()
bc_model_state = spark.sparkContext.broadcast(net.state_dict())


df = spark.sparkContext.parallelize([[np.random.rand() for i in range(5)] for j in range(10)]).toDF()
df = df.withColumn('features', F.array([F.col(f"_{i}") for i in range(1, 6)]))

def get_model_for_eval():
    # Load the broadcast state_dict into the model and switch to eval mode
    net.load_state_dict(bc_model_state.value)
    net.eval()
    return net

def one_row_predict(x):
    model = get_model_for_eval()
    t = torch.tensor(x, dtype=torch.float32)
    prediction = model(t).cpu().detach().item()
    return prediction

one_row_udf = udf(one_row_predict, FloatType())
df = df.withColumn('pred_one_row', one_row_udf(col('features')))
df.show()

Output:

+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+
|                  _1|                 _2|                 _3|                 _4|                 _5|            features|pred_one_row|
+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+
|  0.8447505355266759| 0.3938414671838497|0.46347383092447003| 0.7694022276208854| 0.6152606009215115|[0.84475053552667...| 0.025048971|
|0.023782157504950607| 0.6434186254505012| 0.4090423037706754| 0.5466917794921007| 0.7855157903802007|[0.02378215750495...|  0.19694215|
|  0.5057589877333257| 0.7186078182786649| 0.9123361330966105|  0.601837718628886| 0.0773272396167538|[0.50575898773332...|    0.278222|
|  0.2815336141913932| 0.5196112020157087| 0.9646444599173869|0.04844988843812004|0.35445251642633047|[0.28153361419139...|  0.10699606|
|  0.3896101050146765|0.38732747821339863| 0.8516864705178889| 0.2500977280156421| 0.7781221754566505|[0.38961010501467...| -0.08206403|
|  0.8223344715797269| 0.9089425281658239|0.10088026161623431| 0.9920995834835098|0.40665125930441104|[0.82233447157972...|   0.3565607|
| 0.31167413110257425| 0.9778009876605741| 0.4717549025588036|0.24563879994222826| 0.7594244867194454|[0.31167413110257...|  0.18897778|
|  0.5667657426129576| 0.5383639427018171| 0.2983527299596511|0.18914810241640534|0.47854422807435326|[0.56676574261295...|  0.17796803|
|  0.6419824467244137|0.03992370080139418|0.38462617679839173|  0.709487894249459|0.23020927682221126|[0.64198244672441...|  0.15635887|
|  0.7972928622000178| 0.7700992684264264| 0.4387404431803098| 0.1340696629092989| 0.7072213018683782|[0.79729286220001...|   0.0500246|
+--------------------+-------------------+-------------------+-------------------+-------------------+--------------------+------------+

Doing the same thing in a vectorized way works when I call the function directly on a pandas Series:

def batch_predict(x):
    model = get_model_for_eval()
    xp = np.vstack(x)
    t = torch.tensor(xp, dtype=torch.float32)
    prediction = model(t).cpu().detach().numpy().flatten()
    return pd.Series(prediction)

df_pd = df.toPandas()
x = df_pd['features']
print(batch_predict(x))

But running it inside a pandas_udf fails:

batch_udf = pandas_udf(batch_predict, FloatType())
df = df.withColumn('pred_batch', batch_udf(col('features')))
df.show()

with:

20/02/11 10:13:01 ERROR Executor: Exception in task 2.0 in stage 1.0 (TID 3)
java.lang.IllegalArgumentException
    at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)
    at org.apache.arrow.vector.ipc.message.MessageSerializer.readMessage(MessageSerializer.java:543)
    at org.apache.arrow.vector.ipc.message.MessageChannelReader.readNext(MessageChannelReader.java:58)
    at org.apache.arrow.vector.ipc.ArrowStreamReader.readSchema(ArrowStreamReader.java:132)
    at org.apache.arrow.vector.ipc.ArrowReader.initialize(ArrowReader.java:181)
    at org.apache.arrow.vector.ipc.ArrowReader.ensureInitialized(ArrowReader.java:172)
    at org.apache.arrow.vector.ipc.ArrowReader.getVectorSchemaRoot(ArrowReader.java:65)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$anon$1.read(ArrowPythonRunner.scala:162)
    at org.apache.spark.sql.execution.python.ArrowPythonRunner$anon$1.read(ArrowPythonRunner.scala:122)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:410)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec$anon$2.<init>(ArrowEvalPythonExec.scala:98)
    at org.apache.spark.sql.execution.python.ArrowEvalPythonExec.evaluate(ArrowEvalPythonExec.scala:96)
    at org.apache.spark.sql.execution.python.EvalPythonExec$anonfun$doExecute$1.apply(EvalPythonExec.scala:127)
    at org.apache.spark.sql.execution.python.EvalPythonExec$anonfun$doExecute$1.apply(EvalPythonExec.scala:89)
    at org.apache.spark.rdd.RDD$anonfun$mapPartitions$1$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.RDD$anonfun$mapPartitions$1$anonfun$apply$23.apply(RDD.scala:801)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.run(Task.scala:123)
    at org.apache.spark.executor.Executor$TaskRunner$anonfun$10.apply(Executor.scala:408)