Hi Pumplerod!
Use broadcasting, with unsqueeze() giving A a singleton dimension,
so that the dimensions line up properly:
>>> import torch
>>> torch.__version__
'1.10.2'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> A = torch.randn (3, 5)
>>> B = torch.randn (3, 6, 5)
>>>
>>> resultBroadcast = A.unsqueeze (1) - B
>>> resultBroadcast.shape
torch.Size([3, 6, 5])
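To make the broadcasting step concrete, here is a small script (not part of the
original answer) that prints the shape after unsqueeze() and checks the result
against an explicit triple loop computing result[i, k, j] = A[i, j] - B[i, k, j].
The names A3 and resultLoop are just illustrative:

```python
import torch

torch.manual_seed(2022)

A = torch.randn(3, 5)     # indexed as A[i, j]
B = torch.randn(3, 6, 5)  # indexed as B[i, k, j]

# unsqueeze(1) inserts a size-1 dimension: (3, 5) -> (3, 1, 5).
# Broadcasting then stretches that size-1 dimension to match B's size 6.
A3 = A.unsqueeze(1)
print(A3.shape)  # torch.Size([3, 1, 5])

resultBroadcast = A3 - B

# Reference result built element by element with explicit loops:
# resultLoop[i, k, j] = A[i, j] - B[i, k, j]
resultLoop = torch.empty_like(B)
for i in range(B.shape[0]):
    for k in range(B.shape[1]):
        for j in range(B.shape[2]):
            resultLoop[i, k, j] = A[i, j] - B[i, k, j]

print(torch.equal(resultBroadcast, resultLoop))  # True
```

The loop version is only there to spell out the indexing; the one-line
broadcast is what you would actually use.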
But if you really want to use einsum(), you can convert
the subtraction into a multiplication, because
exp(a) * exp(-b) = exp(a - b), and then take the log():
>>> resultEinsum = torch.einsum('ij,ikj->ikj', A.exp(), (-B).exp()).log()
>>> torch.allclose (resultBroadcast, resultEinsum)
True
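One caveat with the exp() / log() trick: it only works when exp() neither
overflows nor underflows. In float32, exp(x) overflows to inf for x above
roughly 88.7, and exp(-b) underflows to 0 for large positive b (which gives
-inf after the log()). The sketch below, with made-up single-element tensors,
shows the overflow case, where broadcasting gives the right answer but the
einsum() version does not:

```python
import torch

A = torch.tensor([[100.0]])    # shape (1, 1), indexed as A[i, j]
B = torch.tensor([[[1.0]]])    # shape (1, 1, 1), indexed as B[i, k, j]

# Broadcasting computes A - B directly: 100 - 1 = 99
exact = A.unsqueeze(1) - B

# exp(100) is about 2.7e43, larger than float32's max (~3.4e38),
# so A.exp() is inf and the product and its log() stay inf.
viaEinsum = torch.einsum('ij,ikj->ikj', A.exp(), (-B).exp()).log()

print(exact.item())      # 99.0
print(viaEinsum.item())  # inf
```

This is one more reason to prefer the broadcasting version.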
Best.
K. Frank