Many AsStridedBackward0 nodes during the backward pass

Our model needs to split the input data. Formatting the input data inside the model's forward function causes very poor performance when training with torch.compile.

torch: v2.1.0
GPU: V100

class CustomModel(torch.nn.Module):
    def __init__(self) -> None:
        ...

    def format_data(self, data):
        data = torch.reshape(data, [-1, 5, 31024])
        data_list = torch.permute(data, [1, 0, 2])
        data_split_shape = [1,1,1,1,1,950,14,5,1,1,1,1,1,1,5,5,5,5,1,5,5,1,1,1,1,1,14,25,5,1,1,1,1,1,1,1,1,1,5,25,42,42,42,42,3,61,3,61]
        each_data_length = sum(data_split_shape)
        sequence_data_length = each_data_length * 16
        sequence_data_split_shape = [sequence_data_length, 4096, 256, 4096, 256]
        sequence_data, data1, data2, data3, data4 = torch.split(data_list, sequence_data_split_shape, dim=-1)
        sequence_data = torch.reshape(sequence_data, [5, -1, each_data_length])
        each_data_list = [torch.split(sequence_data[i], data_split_shape, dim=-1) for i in range(5)]
        d1 = data1.reshape([-1, 4096])
        d3 = data3.reshape([-1, 4096])
        d2 = data2.reshape([-1, 256])
        d4 = data4.reshape([-1, 256])
        return each_data_list, (d2, d1), (d4, d3)

    def forward(self, data):
        data_list = self.format_data(data)
        return ...

model = CustomModel().to("cuda")
model.train()

compiled_model = torch.compile(model)

total_step = 10000
for step in range(total_step):
    res = compiled_model(data)
    loss = compiled_model.loss_def(res)
    loss.backward()

The per-step latency comparison during training is as follows:

batch_size | format_in_forward | not_in_forward
16         | 601.5 ms          | 781.4 ms

Using torch.profiler, we found a large number of AsStridedBackward0 nodes in the backward pass.
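For reference, this is roughly how such a profile can be collected (a minimal sketch; the activities, sort key, and row limit are arbitrary choices, not the exact settings we used):

from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    res = compiled_model(data)
    loss = compiled_model.loss_def(res)
    loss.backward()

# Backward nodes such as AsStridedBackward0 show up as rows in this table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))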

I found this comment while reading the source code. Is it related to the problem I am facing? How should I solve it? Should I simply move format_data out of forward?
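For context, moving format_data out of forward would look roughly like this (a hypothetical sketch; it assumes forward is changed to accept the already-formatted tensors, which is not how the model is written above):

# Hypothetical: run the reshape/permute/split eagerly, outside the compiled graph,
# and only compile the rest of the model.
each_data_list, pair1, pair2 = model.format_data(data)
res = compiled_model(each_data_list, pair1, pair2)  # forward would need to take these directly
loss = compiled_model.loss_def(res)
loss.backward()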

Hey @pansn do you have a self-contained code repro I can use that replicates the issue? That would be helpful for tracking down the as_strided calls.

@bdhirsh Sorry, my previous code may have been too abstract. Here is a simple, reproducible sample. Could you please take a look and help me out?

import torch
import numpy as np
import time

class CustomModel(torch.nn.Module):
    def format_data(self, data):
        data = torch.reshape(data, [-1, 5, 31024])
        data_list = torch.permute(data, [1, 0, 2])
        data_split_shape = [1,1,1,1,1,950,14,5,1,1,1,1,1,1,5,5,5,5,1,5,5,1,1,1,1,1,14,25,5,1,1,1,1,1,1,1,1,1,5,25,42,42,42,42,3,61,3,61]
        each_data_length = sum(data_split_shape)
        sequence_data_length = each_data_length * 16
        sequence_data_split_shape = [sequence_data_length, 4096, 256, 4096, 256]
        sequence_data, data1, data2, data3, data4 = torch.split(data_list, sequence_data_split_shape, dim=-1)
        sequence_data = torch.reshape(sequence_data, [5, -1, each_data_length])
        each_data_list = [torch.split(sequence_data[i], data_split_shape, dim=-1) for i in range(5)]
        d1 = data1.reshape([-1, 4096])
        d3 = data3.reshape([-1, 4096])
        d2 = data2.reshape([-1, 256])
        d4 = data4.reshape([-1, 256])
        return each_data_list, (d2, d1), (d4, d3)

    def forward(self, data):
        return self.format_data(data)
    
    def loss_def(self, res1, res2, res3):
        loss = 0
        for r in res1:
            for x in r:
                loss = loss + torch.sum(x)
        for r in res2:
            loss = loss + torch.sum(r)
        for r in res3:
            loss = loss + torch.sum(r)
        return loss

device = "cuda"

model = CustomModel().to(device)
model.train()

dtype = torch.float32

compiled_model = torch.compile(model)
warmup_step = 10
total_step = 100

for step in range(total_step):
    if step == warmup_step:
        torch.cuda.synchronize(device=device)
        start_time = time.perf_counter()
    data = torch.tensor(np.random.random([16, 5*31024]) + 0.1, requires_grad=True).type(dtype).to(device)
    res1, res2, res3 = compiled_model(data)
    loss = compiled_model.loss_def(res1, res2, res3)
    loss.backward()

torch.cuda.synchronize(device=device)
end_time = time.perf_counter()

print(
    "torch.compile avg step time: {} ms".format(
        (end_time - start_time) * 1e3 / (total_step - warmup_step)
    )
)

for step in range(total_step):
    if step == warmup_step:
        torch.cuda.synchronize(device=device)
        start_time = time.perf_counter()
    data = torch.tensor(np.random.random([16, 5*31024]) + 0.1, requires_grad=True).type(dtype).to(device)
    res1, res2, res3 = model(data)
    loss = model.loss_def(res1, res2, res3)
    loss.backward()

torch.cuda.synchronize(device=device)
end_time = time.perf_counter()

print(
    "torch avg step time: {} ms".format(
        (end_time - start_time) * 1e3 / (total_step - warmup_step)
    )
)

V100 result

torch.compile avg step time: 49.37022762993971 ms
torch avg step time: 35.64678467810154 ms

Hey @pansn - I believe the AsStridedBackward slowness should be fixed by https://github.com/pytorch/pytorch/pull/111411. I tried running your repro on top of that fix on a nightly, and when I print out the grad_fn of every output of the forward, I no longer see any AsStridedBackward nodes.
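In case it helps others reproduce that check, here is a minimal sketch of printing the grad_fn of every output (iter_tensors is an ad-hoc helper, not part of the original repro):

def iter_tensors(obj):
    # Recursively yield tensors from nested lists/tuples of model outputs.
    if isinstance(obj, torch.Tensor):
        yield obj
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from iter_tensors(item)

res1, res2, res3 = compiled_model(data)
for t in iter_tensors((res1, res2, res3)):
    # On 2.1.0 many of these print AsStridedBackward0; on a nightly with the fix they should not.
    print(type(t.grad_fn).__name__ if t.grad_fn is not None else None)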