Runtime error when running inference on a compiled nn.TransformerEncoder

When I try to run inference on an nn.TransformerEncoder compiled with torch.compile, using PyTorch 2.1.0 and CUDA 12.1, I get the following runtime error:

TorchRuntimeError: Failed running call_module fn(*(FakeTensor(..., device='cuda:0', size=(1, 2, 4)),), **{'src_key_padding_mask': FakeTensor(..., device='cuda:0', size=(1, 2), dtype=torch.bool)}):
meta converter nyi

from user code:
   File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/", line 17, in inner
    return fn(*args, **kwargs)

Here is a minimal example that reproduces the error:

import torch
import torch.nn.functional as F
from torch import nn

encoder_layer = nn.TransformerEncoderLayer(4, 2, 16, batch_first=True, device='cuda')
model = nn.TransformerEncoder(encoder_layer, 2)
model.eval()  # the nested-tensor fast path is only taken outside training mode
model_opt = torch.compile(model)

with torch.inference_mode():
  x = torch.arange(0, .8, .1, device='cuda').reshape((1, 2, 4))
  mask = torch.tensor([[False, True]], device='cuda')
  model_opt(x, src_key_padding_mask=mask)

If I run this code using PyTorch nightly (2.3.0.dev20240228 at the time of writing), I get a different runtime error:

torch._dynamo.exc.TorchRuntimeError: Failed running call_module fn(*(FakeTensor(..., device='cuda:0', size=(1, 2, 4)),), **{'src_key_padding_mask': FakeTensor(..., device='cuda:0', size=(1, 2), dtype=torch.bool)}):
strided nested tensors are not supported by meta conversion

from user code:
   File "/home/nvidia/anaconda3/envs/tfm-david/lib/python3.11/site-packages/torch/_dynamo/", line 25, in inner
    return fn(*args, **kwargs)

The error disappears if I run the model in training mode or set enable_nested_tensor=False when creating the TransformerEncoder.
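A minimal sketch of the enable_nested_tensor=False workaround applied to the repro above. It assumes a PyTorch version where torch.compile is supported on your Python version, falls back to CPU when CUDA is unavailable, and passes backend="eager" (an assumption to keep the sketch lightweight): Dynamo still traces the model with FakeTensors, but Inductor code generation is skipped. Drop that argument to compile with the default backend. The trade-off is that the padded batch is never converted to a nested tensor, so padding positions are computed like any other token.

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

encoder_layer = nn.TransformerEncoderLayer(4, 2, 16, batch_first=True, device=device)
# enable_nested_tensor=False stops TransformerEncoder from packing the padded batch into a nested tensor
model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=False).to(device)
model.eval()  # inference, the same mode in which the nested-tensor path would otherwise be taken

# backend="eager" still runs Dynamo tracing but skips Inductor code generation (assumption for brevity)
model_opt = torch.compile(model, backend="eager")

with torch.inference_mode():
    x = torch.arange(0, .8, .1, device=device).reshape((1, 2, 4))
    mask = torch.tensor([[False, True]], device=device)  # True marks padding
    out = model_opt(x, src_key_padding_mask=mask)
```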

Am I doing something wrong, or is there a way to work around this issue?

This seems to be a known limitation. CC @marksaroufim


Nested tensors should work with torch.compile; here are some examples: segment-anything-fast/segment_anything_fast/ at main · pytorch-labs/segment-anything-fast · GitHub and segment-anything-fast/experiments/ at main · pytorch-labs/segment-anything-fast · GitHub

There might be some quirk specific to TransformerEncoder. I wouldn’t recommend using that class; instead, leverage SDPA to build your own Transformer: (Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA) — PyTorch Tutorials 2.2.1+cu121 documentation
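As a rough illustration of building your own layer on SDPA, here is a minimal sketch. The class name SDPAEncoderLayer and its layout (post-norm, ReLU feed-forward, fused QKV projection) are hypothetical choices loosely mirroring nn.TransformerEncoderLayer; this is not the tutorial’s code. Note the mask conventions differ: src_key_padding_mask uses True for padding, while a boolean attn_mask for F.scaled_dot_product_attention uses True for positions that may be attended, so the mask is inverted. It assumes every sequence has at least one non-padding token (otherwise the softmax over an all-masked row yields NaN).

```python
import torch
import torch.nn.functional as F
from torch import nn


class SDPAEncoderLayer(nn.Module):
    """Hypothetical post-norm encoder layer built on F.scaled_dot_product_attention."""

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % nhead == 0, "d_model must be divisible by nhead"
        self.nhead = nhead
        self.dropout_p = dropout
        self.qkv = nn.Linear(d_model, 3 * d_model)  # fused Q, K, V projection
        self.out_proj = nn.Linear(d_model, d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, key_padding_mask: torch.Tensor = None) -> torch.Tensor:
        # x: (batch, seq, d_model); key_padding_mask: (batch, seq), True = padding
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, seq, d_model) -> (batch, nhead, seq, head_dim)
        q, k, v = (t.reshape(b, s, self.nhead, d // self.nhead).transpose(1, 2) for t in (q, k, v))
        attn_mask = None
        if key_padding_mask is not None:
            # SDPA boolean masks use True = "may attend"; broadcast over heads and queries
            attn_mask = (~key_padding_mask)[:, None, None, :]
        attn = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout_p if self.training else 0.0
        )
        attn = attn.transpose(1, 2).reshape(b, s, d)
        x = self.norm1(x + self.dropout(self.out_proj(attn)))
        return self.norm2(x + self.dropout(self.ff(x)))


torch.manual_seed(0)
layer = SDPAEncoderLayer(d_model=8, nhead=2, dim_feedforward=16).eval()
x = torch.randn(2, 3, 8)
mask = torch.tensor([[False, False, True], [False, False, False]])  # last token of sample 0 is padding
with torch.inference_mode():
    out = layer(x, key_padding_mask=mask)
```

Several such layers can be stacked in an nn.ModuleList inside your own module and passed to torch.compile; because the mask is handled inside SDPA, no nested-tensor conversion is involved.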

I’m not coding today, so sorry I can’t send a minimal repro. @cpuhrsch might be able to share more detail, though.


In the end, I got around the issue by using this custom TransformerEncoder class:

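import copy

import torch
from torch import Tensor, nn


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int):
        super().__init__()  # must run before submodules are assigned
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])

    def forward(self, src: Tensor, src_key_padding_mask: Tensor) -> Tensor:
        output = src
        # The only fast-path check kept from nn.TransformerEncoder: no nested tensors in training mode
        use_nested = not self.layers[0].training
        if use_nested:
            # Pack the padded batch into a nested tensor; logical_not() turns "True = padding" into "True = keep"
            output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)
            # The padding is now encoded in the nested tensor itself, so the layers need no mask
            src_key_padding_mask = None

        for mod in self.layers:
            output = mod(output, is_causal=False, src_key_padding_mask=src_key_padding_mask)

        # Convert back to a dense tensor of the original shape, filling padded positions with 0
        return output.to_padded_tensor(0., src.size()) if use_nested else output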

It is the same as nn.TransformerEncoder, but less parametrized and with all of the fast path checks removed (except the check for whether the encoder layers are in training mode). In PyTorch 2.1.0 it still produces the same error, but not with PyTorch nightly. My guess is that compiling one of the fast path checks was causing the error.
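The custom class relies on the private torch._nested_tensor_from_mask to pack the padded batch and on to_padded_tensor to unpack it. As a rough public-API illustration of that round trip (the helper names padded_to_nested and nested_to_padded are hypothetical, and this is not what nn.TransformerEncoder does internally), the sketch below assumes right-padded sequences, i.e. every valid token precedes the padding, and uses torch.nested.nested_tensor and torch.nested.to_padded_tensor. PyTorch may emit a warning that nested tensors are a prototype feature.

```python
import torch


def padded_to_nested(x: torch.Tensor, key_padding_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pack a right-padded batch into a nested tensor.

    x: (batch, seq, d_model); key_padding_mask: (batch, seq), True = padding.
    Assumes valid tokens are left-aligned (padding only at the end).
    """
    lengths = (~key_padding_mask).sum(dim=1).tolist()  # number of valid tokens per sequence
    return torch.nested.nested_tensor([x[i, :n] for i, n in enumerate(lengths)])


def nested_to_padded(nt: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pad every component back to the shape of `like` with zeros."""
    return torch.nested.to_padded_tensor(nt, 0.0, like.size())


x = torch.arange(12.0).reshape(2, 3, 2)
mask = torch.tensor([[False, False, True], [False, False, False]])  # sample 0 has one padding token
nt = padded_to_nested(x, mask)
back = nested_to_padded(nt, x)
```

The nested tensor only stores the valid tokens, which is why inference with little padding (as in the thread) gains less from this path than heavily padded batches would.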

NestedTensors do work with torch.compile, but support is still quite limited. We’re working on it intensely.

@davidaf3 - I’m happy to see the custom TransformerEncoder works. Does it also meet your performance requirements?

Yes. I’m mostly doing training, but I was worried I was doing something wrong. Still, I got about a 10% increase in inference performance, even though my data has little padding.

Also, torch.compile treats torch.nn modules differently from user-defined modules, right? I ask because I tried copy-pasting the nn.TransformerEncoder code into a new class, and it worked with nested tensors and torch.compile. I did some quick-and-dirty print debugging and found that inside nn.TransformerEncoder’s forward, torch._nested_tensor_from_mask was being called with a FakeTensor, which caused the error, whereas in my copy-pasted class it was being called with the actual tensor.