Recompiles on loss_kwargs changes

I have a VAE whose forward method takes a kl_weight argument. We warm up the KL term, so kl_weight changes during training. If I pass this changing kl_weight to the forward method of a compiled PyTorch module, training becomes extremely slow. In contrast, when I apply the changing kl_weight to the loss outside the compiled model and pass the resulting loss to PyTorch Lightning, compilation gives a significant speedup. Am I doing something wrong? For reference, this is the model I'm trying to optimize: Comparing main...can_compile · scverse/scvi-tools · GitHub

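I figured it out. Passing kl_weight as a torch.tensor instead of a Python float avoids the recompilation. With a float, torch.compile appears to specialize on the value, so every new warm-up value triggers a recompile. A tensor's value is treated as data rather than a compile-time constant.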
