Surprisingly Small Improvement for Transformers

Hey!

Recently, I tried to use torch.compile to accelerate inference for the OCR-free Document Understanding Transformer (Donut) model (SwinTransformer encoder + MBart decoder). However, I saw only a very small improvement in mean latency per call: 0.602±0.016 s → 0.596±0.016 s (~1%). I tried all the available compile modes and various configurations, but none of them improved the result.

To check that my setup and versions are okay, I tested resnet18 and got results that matched the benchmarks: 0.004±0.00005 s → 0.0026±0.0001 s (~35%).
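The percentages quoted here appear to be relative reductions in mean latency, (before - after) / before; a speedup factor, before / after, is another common way to report the same numbers. A minimal sketch of both, with `latency_reduction` and `speedup_factor` as hypothetical helper names and the figures taken from the measurements above:

```python
def latency_reduction(before: float, after: float) -> float:
    """Percent reduction in mean latency: (before - after) / before * 100."""
    return (before - after) / before * 100.0


def speedup_factor(before: float, after: float) -> float:
    """How many times faster the new version is: before / after."""
    return before / after


# Mean per-call latencies in seconds, taken from the measurements above.
measurements = {
    "donut": (0.602, 0.596),
    "resnet18": (0.004, 0.0026),
}

for name, (before, after) in measurements.items():
    print(
        f"{name}: {latency_reduction(before, after):.1f}% lower latency, "
        f"{speedup_factor(before, after):.2f}x faster"
    )
```

By this measure ResNet-18 runs about 1.54x faster after compilation, while Donut runs only about 1.01x faster.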

According to the HuggingFace benchmarks, a model like this should get a much bigger speedup on an A100.

My setup:
• GPU: A100 40GB
• CUDA: 12.0
• torch: 2.3.0
• triton: 2.3.0

What do you think the problem could be?

Here is the code for Donut:

import torch
from time import perf_counter
import numpy as np
from PIL import Image

from datasets import load_dataset
from transformers import DonutProcessor, VisionEncoderDecoderModel


def main() -> None:
    processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")
    model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-rvlcdip")

    model.half()
    model.to(torch.device("cuda"))
    model.eval()
    model = torch.compile(model)

    dataset = load_dataset("hf-internal-testing/example-documents", split="test")
    image = dataset[1]["image"]

    times = []
    for _ in range(150):
        t1 = perf_counter()
        with torch.no_grad():
            task_prompt = "<s_rvlcdip>"
            decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
            pixel_values = processor(image, return_tensors="pt").pixel_values
            _ = model.generate(
                pixel_values.to(model.encoder.device, model.encoder.dtype),
                decoder_input_ids=decoder_input_ids.to(model.decoder.device),
                max_length=model.decoder.config.max_position_embeddings,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )
        t2 = perf_counter()
        times.append(t2 - t1)

    print(np.mean(times[50:]), np.std(times[50:]))


if __name__ == "__main__":
    main()

Here is the code for ResNet-18:

import torch
from time import perf_counter
import numpy as np


def main() -> None:
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)

    model.half()
    model.to(torch.device("cuda"))
    model.eval()
    model = torch.compile(model)

    times = []
    for _ in range(150):
        t1 = perf_counter()
        with torch.no_grad():
            _ = model(torch.randn(2, 3, 128, 128).half().to(torch.device("cuda")))
        t2 = perf_counter()
        times.append(t2 - t1)

    print(np.mean(times[50:]), np.std(times[50:]))


if __name__ == "__main__":
    main()

I also checked ViT, since it was used in the HuggingFace benchmarks. There my result is even better than the benchmark: the benchmark reports 18% (0.009325 s → 0.007584 s), while I got a 30% improvement (0.00767 s → 0.005378 s). So the problem seems to be specific to my model, but I don't understand why, because it appears to consist only of well-known submodels (SwinTransformer encoder, MBart decoder).

Code:

import torch
from PIL import Image
import requests
import numpy as np
from time import perf_counter

from transformers import AutoImageProcessor, AutoModelForImageClassification


def main() -> None:
    url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
    image = Image.open(requests.get(url, stream=True).raw)

    processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224").to("cuda")
    model = torch.compile(model)

    processed_input = processor(image, return_tensors='pt').to(device="cuda")

    times = []
    for _ in range(150):
        t1 = perf_counter()
        with torch.no_grad():
            _ = model(**processed_input)
        t2 = perf_counter()
        times.append(t2 - t1)

    print(np.mean(times[50:]), np.std(times[50:]))


if __name__ == "__main__":
    main()

Your profiling is invalid because you are not synchronizing the code before starting and stopping the host timers. CUDA operations are executed asynchronously, so without synchronization the host timers may only capture the kernel launch overhead rather than the actual GPU execution time.
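Following that advice, here is a minimal sketch of a timing loop that puts a synchronization barrier around each measurement. `benchmark` is a hypothetical helper, and the barrier is passed in as `sync` so the same loop works for CPU-only code (the default no-op) and for CUDA, where you would pass `torch.cuda.synchronize`:

```python
import statistics
import time
from typing import Callable, Tuple


def benchmark(
    fn: Callable[[], object],
    sync: Callable[[], None] = lambda: None,
    warmup: int = 50,
    iters: int = 100,
) -> Tuple[float, float]:
    """Time fn() with a synchronization barrier around each measurement.

    `sync` should block until all queued device work has finished
    (for CUDA, torch.cuda.synchronize). The default no-op suits CPU-only code.
    Returns (mean, stdev) of the per-call times in seconds; iters must be >= 2.
    """
    # Warmup runs are not timed (e.g. so torch.compile can finish compiling).
    for _ in range(warmup):
        fn()
    sync()

    times = []
    for _ in range(iters):
        sync()  # make sure no earlier work is still running on the device
        t1 = time.perf_counter()
        fn()
        sync()  # wait for the work launched by fn() to actually finish
        times.append(time.perf_counter() - t1)
    return statistics.mean(times), statistics.stdev(times)
```

For example, `mean, std = benchmark(lambda: model(x), sync=torch.cuda.synchronize)`, where `model` and `x` are placeholders for the compiled model and an input already on the GPU. Taking `sync` as a parameter keeps the helper free of a hard PyTorch dependency.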


@ptrblck

I tried the following two methods (with 50 steps as warmup):

# Method 1: CUDA events
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
...  # code being timed
end.record()
torch.cuda.synchronize()

delta = start.elapsed_time(end)  # milliseconds

# Method 2: host timer with explicit synchronization
torch.cuda.synchronize()
t1 = perf_counter()
...  # code being timed
torch.cuda.synchronize()
t2 = perf_counter()

delta = t2 - t1  # seconds

However, I got similar results: torch.compile doesn't give any speedup.
Could you please advise what I should try next?

I found the problem.

`torch.compile(model)` only compiles the wrapper's `forward`, but `model.generate(...)` is looked up on the original, uncompiled model, so the encoder and decoder calls inside the generation loop were never compiled. The fix is to compile the submodels separately:

model.encoder = torch.compile(model.encoder)
model.decoder = torch.compile(model.decoder)
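To illustrate why compiling the submodules helps, here is a minimal sketch, assuming a PyTorch version where `torch.compile` is supported on your Python version. `ToyEncoderDecoder` is a hypothetical stand-in for `VisionEncoderDecoderModel`, and `backend="eager"` is used only so the sketch runs on CPU without code generation. It checks that wrapping the whole model returns an `OptimizedModule` whose `.generate` attribute is simply the original model's method, and that after compiling the submodules, a Python-level generation loop calls the compiled versions:

```python
import torch
from torch import nn
from torch._dynamo.eval_frame import OptimizedModule


class ToyEncoderDecoder(nn.Module):
    """Hypothetical stand-in for an encoder-decoder model with generate()."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.decoder = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(x))

    def generate(self, x: torch.Tensor, steps: int = 3) -> torch.Tensor:
        # A Python-level loop defined on the original class that calls the
        # submodules directly (loosely mimicking an autoregressive generate()).
        h = self.encoder(x)
        for _ in range(steps):
            h = self.decoder(h)
        return h


model = ToyEncoderDecoder().eval()

# Wrapping the whole model: calling `wrapped(x)` goes through the compiled
# forward, but other attributes such as `.generate` are forwarded to the
# original, uncompiled module.
wrapped = torch.compile(model, backend="eager")
print(type(wrapped).__name__)              # OptimizedModule
print(wrapped.generate.__self__ is model)  # True

# Compiling the submodules that generate() actually calls.
model.encoder = torch.compile(model.encoder, backend="eager")
model.decoder = torch.compile(model.decoder, backend="eager")
print(isinstance(model.decoder, OptimizedModule))  # True

with torch.no_grad():
    out = model.generate(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
```

One side effect worth knowing: after this replacement, the submodules' parameters appear under an extra `_orig_mod` prefix in `model.state_dict()`, so it is usually simplest to save checkpoints before compiling.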