Mixed dimension difference operation with einsum

Hi Pumplerod!

Use broadcasting, with unsqueeze (1) inserting a singleton dimension so
that A's shape [3, 5] becomes [3, 1, 5] and lines up with B's [3, 6, 5]:

>>> import torch
>>> torch.__version__
'1.10.2'
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> A = torch.randn (3, 5)
>>> B = torch.randn (3, 6, 5)
>>>
>>> resultBroadcast = A.unsqueeze (1) - B
>>> resultBroadcast.shape
torch.Size([3, 6, 5])
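To see what the broadcast actually computes, here is a small script that checks it against an explicit loop. It reuses the shapes and seed from the session above. The loop-based `reference` tensor and the `A[:, None, :]` indexing variant are just my illustrations, not part of the original question, and I'm assuming a reasonably recent PyTorch:

```python
import torch

# Same shapes as in the session above: A is (i, j) = (3, 5), B is (i, k, j) = (3, 6, 5).
torch.manual_seed(2022)
A = torch.randn(3, 5)
B = torch.randn(3, 6, 5)

# unsqueeze(1) inserts a size-1 dimension, so A becomes (3, 1, 5);
# broadcasting then stretches that size-1 dimension across B's k dimension (6).
A3 = A.unsqueeze(1)
print(A3.shape)  # torch.Size([3, 1, 5])

result = A3 - B

# Illustrative explicit-loop reference: result[i, k, j] == A[i, j] - B[i, k, j]
reference = torch.empty_like(B)
for i in range(B.shape[0]):
    for k in range(B.shape[1]):
        reference[i, k] = A[i] - B[i, k]

print(torch.equal(result, reference))  # True

# Indexing with None inserts the same singleton dimension as unsqueeze(1).
print(torch.equal(A[:, None, :] - B, result))  # True
```

Because subtraction is done elementwise in both versions, the results should match exactly, not just approximately.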

But if you really want to use einsum(), you can turn the subtraction into
a multiplication, because exp (a) * exp (-b) = exp (a - b), and then take
the log() of the result:

>>> resultEinsum = torch.einsum('ij,ikj->ikj', A.exp(), (-B).exp()).log()
>>> torch.allclose (resultBroadcast, resultEinsum)
True
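Here is the same einsum() trick as a standalone script, with one caveat worth knowing about: exp() overflows float32 for inputs above roughly 88.7, so the exp / log round trip can return inf even when a - b itself is small. The constant 100.0 tensors below are arbitrary values I picked to trigger the overflow, and the atol=1e-5 passed to allclose() is my own choice of tolerance for the rounding the round trip introduces:

```python
import torch

torch.manual_seed(2022)
A = torch.randn(3, 5)
B = torch.randn(3, 6, 5)

# Reference result via broadcasting.
result_broadcast = A.unsqueeze(1) - B

# einsum only multiplies (and sums), so encode the difference as a product:
# exp(a) * exp(-b) == exp(a - b), then log() recovers a - b up to rounding.
result_einsum = torch.einsum('ij,ikj->ikj', A.exp(), (-B).exp()).log()
print(torch.allclose(result_broadcast, result_einsum, atol=1e-5))  # True

# Caveat: with large inputs exp() overflows to inf in float32,
# so the trick fails even though the true differences are all zero.
A_big = torch.full((3, 5), 100.0)
B_big = torch.full((3, 6, 5), 100.0)
exact = A_big.unsqueeze(1) - B_big  # all zeros
via_exp = torch.einsum('ij,ikj->ikj', A_big.exp(), (-B_big).exp()).log()
print(torch.isfinite(exact).all().item())    # True
print(torch.isfinite(via_exp).all().item())  # False
```

So the einsum() version is mostly a curiosity; the broadcasting version is exact and does not depend on the magnitude of the inputs.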

Best.

K. Frank
