Extremely slow forward/backward speed with a dynamic inference graph

I am training a network, but in each iteration I train only a portion of it. Say 1.0x denotes the full network and 0.5x means half of it (i.e., only the first 0.5x of the full network's weights are used). In each iteration I randomly sample a width and run the forward and backward passes with only that part of the network. I found that the forward and backward passes are very slow this way. However, if I instead sample the width from a fixed set, e.g. [0.25, 0.5, 0.75, 1.0], the speed is much faster.

Is this caused by PyTorch's dynamic computational graph? Even with a fixed width set, a new graph should still be built in every iteration, so I don't understand why there is a difference between the two situations.
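To make the setup concrete, here is a minimal sketch of what I mean (a hypothetical `SlimmableLinear` layer I wrote for illustration; my actual model slices convolution channels the same way):

```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

class SlimmableLinear(nn.Module):
    """Linear layer that can run at a fraction of its full output width.
    Minimal illustration only; the real network slices many layers."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, width):
        # use only the first `k` output units of the full layer
        k = max(1, int(self.weight.shape[0] * width))
        return F.linear(x, self.weight[:k], self.bias[:k])

layer = SlimmableLinear(64, 128)
x = torch.randn(8, 64)

# slow case: a fresh width (and thus fresh tensor shapes) every iteration
width_continuous = random.uniform(0.25, 1.0)

# fast case: widths drawn from a small fixed set, so shapes repeat
width_discrete = random.choice([0.25, 0.5, 0.75, 1.0])

out = layer(x, 0.5)
print(out.shape)  # torch.Size([8, 64])
```

In both cases a new graph is traced each iteration; the only difference is whether the sliced tensor shapes repeat across iterations or are new almost every time.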