I’m in a situation where compiling the whole model does not work, so I’m wrapping each layer with torch.compile. Because my inputs come in a fixed number of different shapes, this adds up to a significant amount of compilation time.
Each layer is just a transformer layer, so I’m wondering whether the compilation has to be redone for each layer, or whether there is a way to share it across layers. Ideally, compilation would only happen when the input shape changes for the first layer, and the remaining layers would reuse the new kernel.
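To make the setup concrete, here is a minimal sketch of the per-layer wrapping plus a rough way to count how often compilation is actually triggered. It is not my real model: `TinyLayer`, `counting_backend`, and `count_compiles` are hypothetical names made up for this example, and the custom backend just runs the captured graph unchanged instead of generating kernels. The counts it reports will likely depend on the PyTorch version, so treat it as a measuring aid rather than an answer.

```python
import torch
import torch.nn as nn

compile_calls = 0  # number of times Dynamo hands a captured graph to the backend


def counting_backend(gm, example_inputs):
    """Hypothetical custom torch.compile backend: count the call, run the FX graph as-is."""
    global compile_calls
    compile_calls += 1
    return gm.forward


class TinyLayer(nn.Module):
    """Stand-in for a transformer layer (assumption: the real layer is much larger)."""

    def __init__(self, dim=16):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ff = nn.Linear(dim, dim)

    def forward(self, x):
        return x + torch.relu(self.ff(self.norm(x)))


def count_compiles(seq_lens=(8, 12), num_layers=3, dim=16):
    """Wrap each layer separately and record backend calls per input shape."""
    global compile_calls
    compile_calls = 0
    # Each layer gets its own torch.compile wrapper, as described above.
    layers = [
        torch.compile(TinyLayer(dim), backend=counting_backend)
        for _ in range(num_layers)
    ]
    counts = {}
    out = None
    for seq_len in seq_lens:
        before = compile_calls
        x = torch.randn(2, seq_len, dim)
        with torch.no_grad():
            for layer in layers:
                x = layer(x)
        out = x
        # Backend calls caused by this shape, summed over all layers.
        counts[seq_len] = compile_calls - before
    return counts, out


if __name__ == "__main__":
    counts, _ = count_compiles()
    print(counts)  # backend calls per sequence length
```

If the count for a shape equals the number of layers, each layer is being compiled on its own. If it is lower, some of the work is already being reused across layers.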