I am able to run the following code on a PySpark DataFrame containing a column of text:
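from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StringType
from torch import tensor
from transformers import AutoModelWithLMHead, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelWithLMHead.from_pretrained("gpt2")
str2tokens = udf(lambda s: tokenizer.tokenize(s, return_tensors='pt'), ArrayType(StringType()))
tokens2ids = udf(lambda s: tokenizer.convert_tokens_to_ids(s), ArrayType(IntegerType()))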
So the PyTorch and Hugging Face libraries load correctly, at least. The problem comes when I try to run inference and write out the logits using the following UDF:
vectorize_ids = udf(lambda s: model(tensor(s)).logits.mean(axis=0).detach().numpy(), ArrayType(DoubleType()))
I get Java heap space errors even when running on tiny input files:
21/01/15 13:17:00 WARN Utils: Your hostname, Evans-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.7.76 instead (on interface en0)
21/01/15 13:17:00 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
21/01/15 13:17:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/01/15 13:17:32 WARN Utils: Suppressing exception in finally: Java heap space
java.lang.OutOfMemoryError: Java heap space
at java.nio.HeapByteBuffer.<init>(HeapByteBuffer.java:57)
at java.nio.ByteBuffer.allocate(ByteBuffer.java:335)
at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$3.apply(TorrentBroadcast.scala:286)
at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$3.apply(TorrentBroadcast.scala:286)
at org.apache.spark.util.io.ChunkedByteBufferOutputStream.allocateNewChunkIfNeeded(ChunkedByteBufferOutputStream.scala:87)
at org.apache.spark.util.io.ChunkedByteBufferOutputStream.write(ChunkedByteBufferOutputStream.scala:75)
at net.jpountz.lz4.LZ4BlockOutputStream.flushBufferedData(LZ4BlockOutputStream.java:220)
at net.jpountz.lz4.LZ4BlockOutputStream.finish(LZ4BlockOutputStream.java:252)
at net.jpountz.lz4.LZ4BlockOutputStream.close(LZ4BlockOutputStream.java:190)
at java.io.ObjectOutputStream$BlockDataOutputStream.close(ObjectOutputStream.java:1828)
at java.io.ObjectOutputStream.close(ObjectOutputStream.java:742)
at org.apache.spark.serializer.JavaSerializationStream.close(JavaSerializer.scala:57)
at org.apache.spark.broadcast.TorrentBroadcast$$anonfun$blockifyObject$1.apply$mcV$sp(TorrentBroadcast.scala:293)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1369)
at org.apache.spark.broadcast.TorrentBroadcast$.blockifyObject(TorrentBroadcast.scala:292)
at org.apache.spark.broadcast.TorrentBroadcast.writeBlocks(TorrentBroadcast.scala:127)
at org.apache.spark.broadcast.TorrentBroadcast.<init>(TorrentBroadcast.scala:88)
at org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast(TorrentBroadcastFactory.scala:34)
at org.apache.spark.broadcast.BroadcastManager.newBroadcast(BroadcastManager.scala:62)
at org.apache.spark.SparkContext.broadcast(SparkContext.scala:1489)
at org.apache.spark.api.java.JavaSparkContext.broadcast(JavaSparkContext.scala:650)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Is there something fundamentally wrong with this approach? I have also tried larger instances on Amazon EMR, with no luck.