I have an experimental setup where the batch size is different at each training iteration. I am using torch.compile() for my model. Whenever the model sees a new batch size, it recompiles, which makes the whole process extremely slow. A minimal sketch of the kind of loop I am running is below.
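This is a simplified stand-in for my actual code; the model, shapes, and batch sizes here are just placeholders to illustrate the pattern:

```python
import torch
import torch.nn as nn

# Placeholder model; the real one is larger.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
compiled_model = torch.compile(model)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# The batch size changes from iteration to iteration.
for step, batch_size in enumerate([32, 48, 64, 27, 91]):
    x = torch.randn(batch_size, 128)
    y = torch.randint(0, 10, (batch_size,))

    optimizer.zero_grad()
    out = compiled_model(x)  # in my runs, an unseen batch size triggers a recompile here
    loss = loss_fn(out, y)
    loss.backward()
    optimizer.step()
```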
Is it possible to have a compiled model that doesn't recompile every time it sees a new batch size?