With PyTorch 2.10 on CPU, the following script:
import os
import sys
import torch
from torch.nn.utils.rnn import pad_sequence
def plain_call(seqs):
out = pad_sequence(seqs, batch_first=True, padding_value=0, padding_side="left")
def compiled_call(seqs):
@torch.compile(fullgraph=True)
def f(xs):
return pad_sequence(xs, batch_first=True, padding_value=0, padding_side="left")
return f(seqs)
if __name__ == "__main__":
    # Guarded so importing this module (e.g. to reuse plain_call /
    # compiled_call) does not trigger a torch.compile run as a side effect.
    import traceback

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([4, 5])
    # Unequal lengths, so left padding is actually exercised.
    seqs = [a, b]

    print("Uncompiled pad_sequence:")
    plain_call(seqs)

    print("Compiled pad_sequence:")
    try:
        # Only the call that can fail lives inside the try.
        out_compiled = compiled_call(seqs)
    except Exception:
        # Deliberately broad: this repro's job is to report whatever the
        # compiled path raises. Print the full traceback, not just the
        # message, because the stack is what identifies the failing layer.
        # It goes to stdout so it stays in order with the prints above.
        print("[ERROR] compiled path raised:")
        traceback.print_exc(file=sys.stdout)
    else:
        print(out_compiled)
        print("\n[OK] compiled path succeeded")
gives the following results:
* Eager path: OK
* `torch.compile(fullgraph=False)`: OK
* `torch.compile(fullgraph=True)`: raises `TypeError: ... pad_sequence() takes from 1 to 3 positional arguments but 4 were given`
Thanks for any help.