torch.compile Llama Slowdown - Unnecessary Copies

Hi there,

I was trying to reproduce the torch.compile speedups for Llama models that I have seen in a few blog posts, here and here.

I see speedups for small batch sizes and sequence lengths, but at large batch sizes and sequence lengths the compiled model is slower than eager mode.
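To pin down where the crossover happens, it can help to time eager and compiled runs over a grid of batch sizes and sequence lengths. Below is a minimal stdlib timing sketch; `benchmark`, `sweep`, and `make_step` are hypothetical helper names, not part of any library. On a GPU you would likely pass `torch.cuda.synchronize` as `sync` so that queued kernels are included in each measurement. The untimed warmup calls matter for torch.compile, since the first calls at a new shape may trigger (re)compilation.

```python
import statistics
import time


def benchmark(fn, *, warmup=3, repeats=10, sync=lambda: None):
    """Median wall-clock seconds of fn() over `repeats` timed calls.

    `sync` runs after each call so asynchronous work is finished before the
    clock is read (on a GPU this would typically be torch.cuda.synchronize).
    """
    for _ in range(warmup):  # untimed; absorbs compilation on the first calls
        fn()
    sync()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        sync()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def sweep(make_step, batch_sizes, seq_lens, **kwargs):
    """Time make_step(bsz, seq_len) for every (bsz, seq_len) combination.

    make_step is a hypothetical factory returning a zero-argument callable
    that runs one forward pass at that batch size and sequence length.
    """
    return {
        (bsz, seq): benchmark(make_step(bsz, seq), **kwargs)
        for bsz in batch_sizes
        for seq in seq_lens
    }
```

Running `sweep` once with an eager `make_step` and once with a version that wraps the compiled model (both hypothetical wrappers around your model call) gives two tables that can be compared cell by cell.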

The profiler shows a large amount of time spent copying data after the model has finished executing:

In the generated code, these copies are at the end:

copy_: "bf16[s0, s1, 4, 64]" = torch.ops.aten.copy_.default(arg204_1, slice_scatter_2);  arg204_1 = slice_scatter_2 = None
copy__1: "bf16[s0, s1, 4, 64]" = torch.ops.aten.copy_.default(arg205_1, slice_scatter_5);  arg205_1 = slice_scatter_5 = None
copy__2: "bf16[s0, s1, 4, 64]" = torch.ops.aten.copy_.default(arg206_1, slice_scatter_8);  arg206_1 = slice_scatter_8 = None
copy__3: "bf16[s0, s1, 4, 64]" = torch.ops.aten.copy_.default(arg207_1, slice_scatter_11);  arg207_1 = slice_scatter_11 = None
copy__4: "bf16[s0, s1, 4, 64]" = torch.ops.aten.copy_.default(arg208_1, slice_scatter_14);  arg208_1 = slice_scatter_14 = None
copy__5: "bf16[s0, s1, 4, 64]" = torch.ops.aten.copy_.default(arg209_1, slice_scatter_17);  arg209_1 = slice_scatter_17 = None
copy__6: "bf16[s0, s1, 4, 64]" = torch.ops.aten.copy_.default(arg210_1, slice_scatter_20);  arg210_1 = slice_scatter_20 = None

They appear to be related to the KV-cache tensors.
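A rough way to see why this would hurt more at large sizes: each copy_ above moves an entire cache buffer of shape bf16[s0, s1, 4, 64], while the cache update itself only writes the new positions. The sketch below does that arithmetic, assuming s0 is the batch size, s1 the cache sequence length, and one new token per step; `copy_back_bytes` and `update_bytes` are hypothetical helper names.

```python
BF16_BYTES = 2  # bfloat16 uses 2 bytes per element


def copy_back_bytes(batch, seq_len, heads=4, head_dim=64, dtype_bytes=BF16_BYTES):
    """Bytes in one full cache buffer of shape [batch, seq_len, heads, head_dim],
    i.e. what a single copy_ of the whole buffer has to write."""
    return batch * seq_len * heads * head_dim * dtype_bytes


def update_bytes(batch, new_tokens=1, heads=4, head_dim=64, dtype_bytes=BF16_BYTES):
    """Bytes the cache update actually needs to write (only the new positions)."""
    return batch * new_tokens * heads * head_dim * dtype_bytes


small = copy_back_bytes(1, 128)   # 65536 bytes (64 KiB)
large = copy_back_bytes(32, 2048)  # 33554432 bytes (32 MiB)
print(small, large, large // update_bytes(32))  # 65536 33554432 2048
```

A copy_ also reads its source, so actual memory traffic is roughly twice these numbers, and the graph has one such copy per cache buffer. The ratio to the real update grows with sequence length, which would be consistent with the slowdown only showing up at large batch sizes and sequence lengths.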

When the newly generated key and value states are written into the KV cache, torch.compile generates this code:

# File: /home/, code: k_cache[:bsz, current_seq_pos, :, :] = key_states
slice_5: bf16[s0, s1, 4, 64] = torch.ops.aten.slice.Tensor(arg206_1, 2, 0, 9223372036854775807)
slice_6: bf16[s0, s1, 4, 64] = torch.ops.aten.slice.Tensor(slice_5, 3, 0, 9223372036854775807);  slice_5 = None
index_put: bf16[s0, s1, 4, 64] = torch.ops.aten.index_put.default(slice_6, [None, arg250_1], permute_6);  slice_6 = permute_6 = None
slice_7: bf16[s0, s1, 4, 64] = torch.ops.aten.slice.Tensor(arg206_1, 2, 0, 9223372036854775807)
slice_scatter: bf16[s0, s1, 4, 64] = torch.ops.aten.slice_scatter.default(slice_7, index_put, 3, 0, 9223372036854775807);  slice_7 = index_put = None
slice_scatter_1: bf16[s0, s1, 4, 64] = torch.ops.aten.slice_scatter.default(arg206_1, slice_scatter, 2, 0, 9223372036854775807);  slice_scatter = None

slice_scatter_1 (and the slice_scatter results that follow it) are the tensors that get copied back into the original KV-cache buffers (arg206_1, arg207_1, etc.) at the end of the graph.

It looks like writing key_states into the slice k_cache[:bsz, current_seq_pos, :, :] is not done in place. Instead, a separate buffer is created, and its contents have to be copied back into k_cache so that the cache can be used again in the next iteration.
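A toy model of what seems to be happening: when torch.compile functionalizes the graph, an in-place write into a graph input is expressed as an out-of-place result (the slice_scatter chain) followed by a copy_ back into the input at the end. The plain-Python sketch below contrasts the two; each scalar stands in for a [4, 64] head vector, and `eager_update` / `functionalized_update` are hypothetical names. It illustrates the general idea only and is not how Inductor implements it.

```python
import copy


def eager_update(cache, bsz, pos, new_rows):
    """Eager-style in-place write: touches only cache[b][pos] for b < bsz."""
    for b in range(bsz):
        cache[b][pos] = new_rows[b]
    return bsz  # number of elements written


def functionalized_update(cache, bsz, pos, new_rows):
    """Functionalized form: build a new full-size buffer out of place
    (like slice_scatter), then copy all of it back into the input (like copy_)."""
    new_cache = copy.deepcopy(cache)  # fresh full-size buffer
    for b in range(bsz):
        new_cache[b][pos] = new_rows[b]
    copied = 0
    for b, row in enumerate(new_cache):  # epilogue: copy_ back into the input
        cache[b][:] = row  # overwrite contents, keep the same list objects
        copied += len(row)
    return copied


eager_cache = [[0] * 8 for _ in range(4)]  # 4 batch slots x 8 positions
func_cache = copy.deepcopy(eager_cache)
print(eager_update(eager_cache, 2, 3, [7, 9]))          # 2
print(functionalized_update(func_cache, 2, 3, [7, 9]))  # 32
print(eager_cache == func_cache)                        # True
```

Both versions end with the same cache contents, but the functionalized one writes the whole buffer instead of just the updated positions.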

Trying to force the update to happen in place with k_cache[slice].copy_(...) did not change the result.

Is there a way to get torch.compile to recognize this as an in-place operation and update the buffer directly, so that slice_scatter_1 is replaced with arg206_1 and these copies are avoided? According to its profile, eager mode does update the buffer in place.

Thanks in advance!

Do you have a self-contained repro of the code you're using that shows the slowdown?

Also, can you try running on a nightly build and see whether performance improves? There have been some recent changes to Inductor's reinplacing logic that may or may not make this faster.
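For context, the general idea behind a reinplacing pass can be sketched on a toy graph: an out-of-place slice_scatter may safely overwrite its base buffer when nothing later in the graph reads the old base value, and in that case the trailing copy_ back into the base becomes redundant. The `Node`, `can_reinplace`, and `reinplace` names below are hypothetical, and this is a simplified illustration of the concept, not Inductor's actual reinplacing logic.

```python
from dataclasses import dataclass


@dataclass
class Node:
    name: str
    op: str        # e.g. "slice_scatter", "attention", "copy_"
    args: tuple    # names of the values this node reads
    inplace: bool = False


def can_reinplace(nodes, idx):
    """True if nodes[idx] (a scatter) may overwrite its base (args[0]) in place.

    Safe when no later node reads the old base value, except a copy_
    that writes this scatter's result straight back into the same base.
    """
    scatter = nodes[idx]
    base = scatter.args[0]
    for later in nodes[idx + 1:]:
        if base not in later.args:
            continue
        if later.op == "copy_" and later.args == (base, scatter.name):
            continue  # the copy-back itself does not need the old value
        return False  # some node still needs the pre-update buffer
    return True


def reinplace(nodes):
    """Mark safe scatters as in-place and drop their now-redundant copy_ nodes."""
    inplace_names = set()
    redundant_copies = set()
    for i, node in enumerate(nodes):
        if node.op == "slice_scatter" and can_reinplace(nodes, i):
            inplace_names.add(node.name)
            redundant_copies.add((node.args[0], node.name))
    result = []
    for node in nodes:
        if node.op == "copy_" and node.args in redundant_copies:
            continue  # base already holds the updated values
        result.append(Node(node.name, node.op, node.args,
                           inplace=node.name in inplace_names))
    return result


# Only the copy-back reads "cache" after the scatter, so it can be reinplaced.
safe = [
    Node("k_scatter", "slice_scatter", ("cache", "k_new")),
    Node("attn", "attention", ("q", "k_scatter")),
    Node("copy", "copy_", ("cache", "k_scatter")),
]
# Here attention reads the old "cache" value, so the copy must stay.
unsafe = [
    Node("k_scatter", "slice_scatter", ("cache", "k_new")),
    Node("attn", "attention", ("q", "cache")),
    Node("copy", "copy_", ("cache", "k_scatter")),
]
print([(n.name, n.inplace) for n in reinplace(safe)])
# [('k_scatter', True), ('attn', False)]
print([(n.name, n.inplace) for n in reinplace(unsafe)])
# [('k_scatter', False), ('attn', False), ('copy', False)]
```

In this toy model, whether the copy can be removed depends on whether anything still reads the pre-update cache after the write, which is one reason a small change to how the cache is read or written can affect whether the copies disappear.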