I have the following toy code:

    import torch

    # batchSize and speckle are defined elsewhere
    modeField = torch.randn([10, 64, 64])
    paras = torch.randn([batchSize, 10])
    for bat in range(batchSize):
        res = 0  # weighted sum of the 10 modes for this batch entry
        for i in range(10):
            res += modeField[i] * paras[bat][i]
        loss = (res**2 - speckle)**2

This is taking really long. Is there any way to do a batched multiplication that gives the result for all batch entries and parameters at once, instead of using for loops?
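One possible vectorized form of the loop is sketched below using `torch.einsum`, which contracts the shared mode index `i` in a single call. The batch size of 4 and the random `speckle` target are hypothetical stand-ins for the values in the original code, and the loop version is kept only to check that both give the same result.

```python
import torch

batch_size = 4                  # hypothetical batch size for illustration
modeField = torch.randn(10, 64, 64)
paras = torch.randn(batch_size, 10)
speckle = torch.randn(64, 64)   # hypothetical target, assumed to be 64x64

# Loop version: weighted sum of the 10 modes for each batch entry
res_loop = torch.stack([
    sum(modeField[i] * paras[b, i] for i in range(10))
    for b in range(batch_size)
])

# Vectorized: paras (b, i) and modeField (i, h, w) -> res (b, h, w)
res = torch.einsum('bi,ihw->bhw', paras, modeField)

# speckle broadcasts over the batch dimension, giving one loss map per entry
loss = (res**2 - speckle)**2
print(res.shape)
```

The einsum string reads as "for every batch entry `b` and pixel `(h, w)`, sum `paras[b, i] * modeField[i, h, w]` over `i`", which is exactly what the inner loop computes.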
I’ve looked into matmul and bmm, but I don’t really understand them, and it seems like they don’t multiply a whole array by a single value.
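A plain matrix product can still do this if each 64x64 mode is flattened into a row: row `b` of `paras @ modes` is then the sum over `i` of `paras[b, i]` times mode `i`, so every scalar scales an entire mode. The sketch below assumes a hypothetical batch size of 4. `torch.bmm` is not needed here, because it expects both operands to carry a batch dimension, shapes `(b, n, m)` and `(b, m, p)`, while `modeField` is shared across the batch.

```python
import torch

batch_size = 4  # hypothetical batch size for illustration
modeField = torch.randn(10, 64, 64)
paras = torch.randn(batch_size, 10)

# Flatten each 64x64 mode into one row: (10, 64, 64) -> (10, 4096)
modes_flat = modeField.reshape(10, -1)

# (b, 10) @ (10, 4096) -> (b, 4096); row b = sum_i paras[b, i] * mode_i
res = (paras @ modes_flat).reshape(batch_size, 64, 64)

# Explicit weighted sum for the first batch entry, for comparison
ref = sum(paras[0, i] * modeField[i] for i in range(10))
print(torch.allclose(res[0], ref, atol=1e-5))
```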
I guess I could map each parameter value onto an array of the same size as modeField, but that wouldn’t be a very elegant solution.
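Broadcasting gives the effect of "mapping each parameter onto an array of the same size" without building that array by hand: giving `paras` two trailing singleton dimensions lets each scalar pair with a full 64x64 mode. A minimal sketch, again assuming a hypothetical batch size of 4, follows. Note that the product does create a `(batch, 10, 64, 64)` intermediate before the sum, so for large batches the einsum or matmul forms may use less memory.

```python
import torch

batch_size = 4  # hypothetical batch size for illustration
modeField = torch.randn(10, 64, 64)
paras = torch.randn(batch_size, 10)

# (b, 10) -> (b, 10, 1, 1); broadcasting against (10, 64, 64) gives (b, 10, 64, 64)
weighted = paras[:, :, None, None] * modeField

# Sum over the 10 modes -> (b, 64, 64)
res = weighted.sum(dim=1)
```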