How can I optimize this calculation?

I have the following toy code:

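import torch

batchSize = 3                          # example batch size
modeField = torch.randn([10, 64, 64])  # 10 mode fields, each 64x64
paras = torch.randn([batchSize, 10])   # one weight per mode, per batch element
speckle = torch.randn([64, 64])        # placeholder target (not defined in the original snippet)
for bat in range(batchSize):
    res = 0
    for i in range(10):
        res += modeField[i] * paras[bat][i]  # weighted sum of the mode fields
    loss = (res**2 - speckle)**2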

This is taking really long. Is there a way to use batched multiplication to get the result for all batches and parameters at once, instead of using for loops?

I've looked into matmul and bmm, but I don't really understand them, and it seems they don't multiply a whole array by a single value.
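The loop computes, for each batch index b and pixel (x, y), a weighted sum of the 10 modes. Writing P for paras and M for modeField, and flattening each 64x64 mode into a row of length 4096 (the matrix written M̃ below is just notation for that flattened view), the same sum becomes an ordinary matrix product. Each row of the result is a linear combination of the rows of M̃, i.e. every mode scaled by a single value and summed, which is exactly the operation described above:

```latex
R_{b,x,y} = \sum_{i=1}^{10} P_{b,i}\, M_{i,x,y}
\qquad\Longrightarrow\qquad
\tilde{R} = P\,\tilde{M}, \quad
P \in \mathbb{R}^{B \times 10},\;
\tilde{M} \in \mathbb{R}^{10 \times 4096},\;
\tilde{R} \in \mathbb{R}^{B \times 4096}
```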

I guess I could map each parameter value onto an array of the same size as modeField, but that wouldn't be a very elegant solution.
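The "map each parameter onto an array" idea does not need a hand-built copy: PyTorch broadcasting can stretch each scalar weight over its 64x64 mode automatically. The sketch below assumes the same shapes as the toy code (10 modes of 64x64, batchSize = 3) and checks the result against the original double loop. Note that it materializes an intermediate tensor of shape [batchSize, 10, 64, 64], so it uses more memory than a matmul-based version.

```python
import torch

batchSize = 3
modeField = torch.randn([10, 64, 64])
paras = torch.randn([batchSize, 10])

# Reshape paras to [batchSize, 10, 1, 1] so each scalar weight broadcasts
# over its whole 64x64 mode field (modeField broadcasts over the batch dim).
weighted = paras[:, :, None, None] * modeField  # shape [batchSize, 10, 64, 64]
result = weighted.sum(dim=1)                    # sum over modes -> [batchSize, 64, 64]

# Reference: the original double loop, stacked into one tensor
ref = torch.stack([
    sum(modeField[i] * paras[bat][i] for i in range(10))
    for bat in range(batchSize)
])
print(result.shape)  # torch.Size([3, 64, 64])
print(torch.allclose(result, ref, atol=1e-5))
```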

import torch 

batchSize = 3
modeField = torch.randn([10, 64, 64])
paras = torch.randn([batchSize, 10])

# your method
your_method_result = []
for bat in range(batchSize):
    res = 0 
    for i in range(10):
        res += modeField[i] * paras[bat][i]
    your_method_result.append(res)
your_method_result = torch.stack(your_method_result, dim=0)

# faster implementation
result = torch.matmul(paras, modeField.view(10, -1)).view(batchSize, 64, 64)

# show the largest difference; it should be tiny (around 1e-6)
print((your_method_result - result).abs().max())
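The same contraction over the 10 modes can also be written with `torch.einsum` or `torch.tensordot`, which avoid the manual `view` calls. The sketch below also computes the loss from the question for every batch element at once; `speckle` is a hypothetical placeholder, since the original post does not define it, and its shape is assumed to be [64, 64] so it broadcasts over the batch dimension.

```python
import torch

batchSize = 3
modeField = torch.randn([10, 64, 64])
paras = torch.randn([batchSize, 10])
speckle = torch.randn([64, 64])  # hypothetical target; shape assumed to be [64, 64]

# Contract index i (the 10 modes): result[b, x, y] = sum_i paras[b, i] * modeField[i, x, y]
result = torch.einsum('bi,ixy->bxy', paras, modeField)  # [batchSize, 64, 64]

# tensordot expresses the same contraction: paras dim 1 against modeField dim 0
result_td = torch.tensordot(paras, modeField, dims=([1], [0]))  # [batchSize, 64, 64]

# Per-pixel loss for all batch elements at once; speckle broadcasts over dim 0
loss = (result**2 - speckle)**2  # [batchSize, 64, 64]
```

If a single scalar is needed for backpropagation, reducing with something like `loss.mean()` is one common choice; the original snippet does not specify a reduction.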

It works perfectly. Thank you.