Operation fusion when many redundant computations are present?

I have a very strange task that requires pairs of inputs to be passed through a neural network, where the pairs form a many-to-many relationship, and each "many" is really "very many". This means that for each input tensor from set 1, a lot of redundant computation is repeated for every comparison against a tensor from set 2. I believe operation fusion could speed this up: fusing a repeated input with a shared layer seems conceptually similar to fusing a conv and a BN layer. However, I cannot seem to find a good way to do this. Most of the ideas I have come up with do not actually perform fusion.
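For intuition, here is one way such redundancy can sometimes be folded out by hand, assuming the first layer is a `Linear` that consumes concatenated pairs (the layer sizes and set sizes below are made up for illustration). Since `W @ cat(x1, x2) + b == W1 @ x1 + W2 @ x2 + b` when `W` is split column-wise into `[W1 | W2]`, each half can be computed once per set element and then broadcast-added across all pairs:

```python
import torch

torch.manual_seed(0)
d = 8
n1, n2 = 100, 200  # sizes of set 1 and set 2 (illustrative)
layer = torch.nn.Linear(2 * d, 16)  # hypothetical first layer over a pair

x1 = torch.randn(n1, d)
x2 = torch.randn(n2, d)

# Naive: materialize every pair and run the layer n1 * n2 times.
pairs = torch.cat(
    [x1.repeat_interleave(n2, dim=0), x2.repeat(n1, 1)], dim=1
)
naive = layer(pairs).view(n1, n2, -1)

# "Fused": split the weight, compute each half once, broadcast the sum.
W1, W2 = layer.weight[:, :d], layer.weight[:, d:]
h1 = x1 @ W1.T                # computed once per element of set 1
h2 = x2 @ W2.T + layer.bias   # computed once per element of set 2
fused = h1[:, None, :] + h2[None, :, :]  # (n1, n2, 16)

print(torch.allclose(naive, fused, atol=1e-5))  # True
```

This only removes redundancy up to the first nonlinearity, of course; past a ReLU the pairwise activations genuinely differ, so the trick applies layer by layer only where the computation is still linear in each input.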

Does anyone know of a good way to do this, or whether it is even possible natively in PyTorch?
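One pattern that works natively in PyTorch, in case it fits the task, is a two-tower setup: run the expensive encoder once per input, cache the embeddings, and apply only a cheap comparison per pair. The encoder and sizes below are placeholders, not a claim about the actual architecture:

```python
import torch

torch.manual_seed(0)
# Illustrative encoder: the expensive part runs once per input, not once per pair.
encoder = torch.nn.Sequential(
    torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16)
)

set1 = torch.randn(100, 32)
set2 = torch.randn(200, 32)

with torch.no_grad():
    e1 = encoder(set1)  # (100, 16), computed once
    e2 = encoder(set2)  # (200, 16), computed once

# All 100 * 200 pairwise scores from one matmul on the cached embeddings.
scores = e1 @ e2.T  # shape (100, 200)
print(scores.shape)
```

This reduces the encoder cost from O(n1 * n2) forward passes to O(n1 + n2), leaving only the pairwise head at full quadratic cost.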

Additionally, does my hypothesis seem correct, i.e. would this measurably boost performance?