I have a relatively simple architecture for a graph neural net where I first process an input to a hidden dimension and then pass it and other inputs preprocessed in the same way into a MultiheadAttention layer. However, I noticed that at inference time all of my attention scores are equal, which isn’t the case during training.
Digging into this, it is caused by all the dot products between the projected features being 0. Is there a reason this might be happening?
It seems that if I don't call model.eval(), the attention scores look more like what I see during training. Could this be related to the fast-path optimizations PyTorch applies in eval mode?
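For reference, here is a minimal sketch of how the train/eval comparison could be isolated to the attention layer alone. It is not my actual model: the sizes, the random input, and the helper name attention_weights are made up for illustration, and it assumes self-attention with batch_first=True and dropout=0.0 so that dropout cannot account for any difference between the two modes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def attention_weights(mha: nn.MultiheadAttention, x: torch.Tensor, training: bool) -> torch.Tensor:
    """Hypothetical helper: return head-averaged self-attention weights in the given mode."""
    mha.train(training)  # train(False) is equivalent to eval()
    with torch.no_grad():
        _, weights = mha(x, x, x, need_weights=True)
    return weights


# Illustrative sizes only; these are assumptions, not the real model's dimensions.
embed_dim, num_heads, batch, seq_len = 16, 4, 2, 5
mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, batch_first=True)
x = torch.randn(batch, seq_len, embed_dim)

w_train = attention_weights(mha, x, training=True)
w_eval = attention_weights(mha, x, training=False)

# Largest elementwise gap between the two modes.
max_diff = (w_train - w_eval).abs().max().item()

# Uniform attention would mean every entry equals 1 / seq_len.
eval_is_uniform = torch.allclose(w_eval, torch.full_like(w_eval, 1.0 / seq_len))

print("max |train - eval|:", max_diff)
print("eval weights uniform:", eval_is_uniform)
```

If the two modes agree on this isolated layer but not inside the full model, the difference is more likely to come from whatever feeds the attention layer (for example dropout or normalization layers in the preprocessing) than from the attention layer itself. Running the same comparison on the real inputs to the layer, in both modes, would help confirm that.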