Hello everyone, hope you are all having a great time.
I was recently thinking about module fusion. Since PyTorch has many examples of fused operators, I tried a simple case to see how much of an improvement to expect. I didn't see much of a gain, and in some cases the results were even worse.
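As an example, I compared these two versions:
```python
import torch
import torch.nn as nn

# Unfused: three separate projections
key = nn.Linear(5, 5, bias=False)
query = nn.Linear(5, 5, bias=False)
value = nn.Linear(5, 5, bias=False)
x = torch.randn(size=(3, 2, 5))
k, q, v = [m(x) for m in (key, query, value)]

# Fused: one projection to 3 * 5 features, then split into k, q, v
kqv = nn.Linear(5, 15, bias=False)
kqv_output = kqv(x)
k2, q2, v2 = torch.split(kqv_output, 5, dim=-1)
```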
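In the snippet above, `kqv` has its own random weights, so `k2, q2, v2` are not numerically equal to `k, q, v`; only the shapes match. A minimal sketch of why the fusion is valid, assuming the goal is an exact replacement: `nn.Linear` stores its weight as `(out_features, in_features)`, so stacking the three weights along dim 0 gives the `(15, 5)` weight of the fused layer, and splitting its output recovers the three projections. The `matches` name is only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 5

key = nn.Linear(d, d, bias=False)
query = nn.Linear(d, d, bias=False)
value = nn.Linear(d, d, bias=False)

# Fused layer whose weight is the row-wise stack [W_k; W_q; W_v], shape (3*d, d)
kqv = nn.Linear(d, 3 * d, bias=False)
with torch.no_grad():
    kqv.weight.copy_(torch.cat([key.weight, query.weight, value.weight], dim=0))

x = torch.randn(3, 2, d)
k, q, v = (m(x) for m in (key, query, value))
# Columns [0:d] of kqv(x) equal x @ W_k^T, [d:2d] equal x @ W_q^T, and so on
k2, q2, v2 = torch.split(kqv(x), d, dim=-1)

matches = [torch.allclose(a, b, atol=1e-6) for a, b in ((k, k2), (q, q2), (v, v2))]
print(matches)  # expected: [True, True, True]
```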
For larger sizes, the unfused version performs better; for smaller sizes, the fused (second) version seems better.
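Timings like these are easy to skew (no warm-up, a single run, asynchronous CUDA kernels), so here is a sketch of a more controlled comparison using `torch.utils.benchmark.Timer`, whose documentation says it performs warm-ups and synchronizes CUDA when needed. It defaults to `num_threads=1`, so the sketch passes the current thread count explicitly. The helper `time_projections`, the batch/sequence sizes, and the list of `d` values are hypothetical choices. Also note that `torch.split` returns views, so `k2, q2, v2` are non-contiguous, which may affect whatever consumes them; this sketch times only the projections.

```python
import torch
import torch.nn as nn
from torch.utils import benchmark


def time_projections(d, batch=32, seq=64, device="cpu"):
    """Return (unfused, fused) median seconds for one forward pass of the projections."""
    key = nn.Linear(d, d, bias=False).to(device)
    query = nn.Linear(d, d, bias=False).to(device)
    value = nn.Linear(d, d, bias=False).to(device)
    kqv = nn.Linear(d, 3 * d, bias=False).to(device)
    x = torch.randn(batch, seq, d, device=device)

    env = {"torch": torch, "x": x, "d": d,
           "key": key, "query": query, "value": value, "kqv": kqv}
    threads = torch.get_num_threads()  # Timer defaults to a single thread
    unfused = benchmark.Timer(stmt="k, q, v = key(x), query(x), value(x)",
                              globals=env, num_threads=threads)
    fused = benchmark.Timer(stmt="k, q, v = torch.split(kqv(x), d, dim=-1)",
                            globals=env, num_threads=threads)
    # blocked_autorange repeats the statement until at least min_run_time seconds elapse
    t_unfused = unfused.blocked_autorange(min_run_time=0.2).median
    t_fused = fused.blocked_autorange(min_run_time=0.2).median
    return t_unfused, t_fused


if __name__ == "__main__":
    for d in (5, 64, 256):
        t_u, t_f = time_projections(d)
        print(f"d={d:4d}  unfused={t_u * 1e6:9.1f} us  fused={t_f * 1e6:9.1f} us")
```

Sweeping `d` (and `device="cuda"` if available) shows where the crossover happens on a given machine, rather than relying on one pair of measurements.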
So what is the culprit here? Is the abstraction causing the second method to perform worse? Or are there any rules of thumb for fusion, in general or in PyTorch specifically?
Thanks a lot in advance