Our model needs to split the input data into many named sub-tensors. Formatting the input data inside the model's `forward` function causes extremely poor performance when training with `torch.compile`.

Environment: torch v2.1.0, GPU: V100.
class CustomModel(torch.nn.Module):
    """Model whose raw input batch must be split into many named sub-tensors.

    NOTE(review): the original excerpt wrote ``def CustomModel(...)`` (a
    syntax error — fixed to ``class``) and hard-coded the reshape width as
    121024.  The split widths used below only sum to 31024
    (``sum(data_split_shape) * 16 + 4096 + 256 + 4096 + 256``), and
    ``torch.split`` requires the widths to sum exactly to the dimension
    size, so 121024 could never have worked; the width is now derived from
    the split widths themselves.
    """

    def __init__(self) -> None:
        super().__init__()  # required before registering submodules on nn.Module
        ...  # network body elided in the report

    def format_data(self, data):
        """Split a flat input batch into per-field tensors.

        Args:
            data: tensor whose element count is a multiple of
                ``5 * 31024`` (reshaped to ``[-1, 5, feature_dim]``).

        Returns:
            ``(each_data_list, (d2, d1), (d4, d3))`` where
            ``each_data_list`` holds, for each of the 5 groups, the tuple
            of 48 per-field splits, and ``d1``/``d3`` are ``[-1, 4096]``,
            ``d2``/``d4`` are ``[-1, 256]``.
        """
        # Per-record widths of the 48 fields inside one sequence record.
        data_split_shape = [1, 1, 1, 1, 1, 950, 14, 5, 1, 1, 1, 1, 1, 1,
                            5, 5, 5, 5, 1, 5, 5, 1, 1, 1, 1, 1, 14, 25, 5,
                            1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 25, 42, 42, 42,
                            42, 3, 61, 3, 61]
        each_data_length = sum(data_split_shape)        # 1395
        sequence_data_length = each_data_length * 16    # 16 records per group
        sequence_data_split_shape = [sequence_data_length, 4096, 256, 4096, 256]
        # Derive the full feature width (31024) instead of hard-coding it,
        # so the reshape and the split below can never disagree.
        feature_dim = sum(sequence_data_split_shape)

        data = torch.reshape(data, [-1, 5, feature_dim])
        # Materialize the permutation: splitting a permuted (non-contiguous)
        # view makes autograd record one AsStridedBackward0 node per split
        # output, which is what dominated backward time under torch.compile.
        # .contiguous() keeps values identical but lets the splits be plain
        # narrow views with cheap backward.
        grouped = torch.permute(data, [1, 0, 2]).contiguous()

        sequence_data, data1, data2, data3, data4 = torch.split(
            grouped, sequence_data_split_shape, dim=-1
        )
        sequence_data = torch.reshape(sequence_data, [5, -1, each_data_length])
        each_data_list = [
            torch.split(sequence_data[i], data_split_shape, dim=-1)
            for i in range(5)
        ]
        d1 = data1.reshape([-1, 4096])
        d3 = data3.reshape([-1, 4096])
        d2 = data2.reshape([-1, 256])
        d4 = data4.reshape([-1, 256])
        return each_data_list, (d2, d1), (d4, d3)

    def forward(self, data):
        """Format the raw batch, then run the (elided) network body."""
        data_list = self.format_data(data)
        return ...
# Training-loop excerpt from the report.  ``data`` and ``loss_def`` are
# defined elsewhere and not shown here.
model = CustomModel().to("cuda")
model.train()
compiled_model = torch.compile(model)
total_step = 10000
for step in range(total_step):
    res = compiled_model(data)
    # loss_def is an attribute of the underlying model, accessed through
    # the torch.compile wrapper (forwarded to the original module).
    loss = compiled_model.loss_def(res)
    # NOTE(review): gradients are never zeroed and no optimizer.step() is
    # shown — presumably elided from the report; confirm against the full
    # training script.
    loss.backward()
The per-step latency comparison during training is as follows:

| batch_size | format_in_forward | not_in_forward |
|---|---|---|
| 16 | 601.5 ms | 781.4 ms |

Through torch.profiler, we found a large number of AsStridedBackward0 nodes during the backward pass.