Sharing torch.compile kernels between layers

I’m in a situation where compiling the whole model does not work, so I’m wrapping each layer with torch.compile individually. Because my inputs come in a fixed set of varying shapes, a significant amount of time is spent on compilation.

Each layer is just a transformer layer, so I’m wondering whether the compilation has to be redone for each layer, or whether there is a way to share it across layers, so that compilation only happens when the input shape changes for the first layer (and the remaining layers inherit the new kernel).
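For context, here is roughly what I’m doing (the TransformerLayer module below is just a stand-in for my real layer; passing dynamic=True is something I’m trying to cut down on shape-specific recompiles):

```python
import torch
import torch.nn as nn

class TransformerLayer(nn.Module):
    """Stand-in for the real transformer layer."""
    def __init__(self, d_model: int = 512, n_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + attn_out)
        return self.norm2(x + self.mlp(x))

layers = nn.ModuleList([TransformerLayer() for _ in range(12)])

# Wrap each layer separately because compiling the whole model fails.
# dynamic=True asks Dynamo to trace shape-polymorphic graphs, which should
# reduce recompiles when the (fixed set of) input shapes varies.
compiled_layers = [torch.compile(layer, dynamic=True) for layer in layers]

x = torch.randn(4, 128, 512)  # (batch, seq, d_model)
for layer in compiled_layers:
    x = layer(x)
```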

@ezyang @ptrblck

Caching should already be used, and you can check this doc for more information.
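If you want to make the caching behavior more explicit, something along these lines may help (a rough sketch; the environment variable names and defaults can differ between PyTorch releases):

```python
import os

# Persist Inductor's compiled artifacts across processes
# (the directory path here is arbitrary).
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor_cache"
# Enable the FX graph cache explicitly (already on by default in recent releases).
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"

import torch

# Raise the per-code-object recompile limit if the fixed set of input shapes
# triggers more recompiles than the default allows.
torch._dynamo.config.cache_size_limit = 64

# Running with TORCH_LOGS=recompiles prints why a layer is being recompiled,
# which is useful to verify whether later layers actually reuse earlier work.
```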
