Hi Jun!
If I understand your question correctly, you wish to apply some function, f(), to each scalar product of the form m * n in your matrix multiplication before you add them together.
That is, for A with shape [3, 2] and B of shape [2, 3], your product matrix A @ B will have shape [3, 3], where each element of A @ B will be a sum of two terms, and you want to apply f() to each of those eighteen terms individually before adding them together.
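In loop form (a naive sketch just to make the bookkeeping explicit; the placeholder f() here is the identity, standing in for your actual function):

```python
import torch

def f(x):
    return x  # placeholder; substitute your element-wise function

torch.manual_seed(0)
A = torch.randn(3, 2)
B = torch.randn(2, 3)

# Each of the 3 x 3 output elements sums two f()-transformed products,
# for eighteen terms in total.
C = torch.zeros(3, 3)
for i in range(3):
    for k in range(3):
        for j in range(2):
            C[i, k] += f(A[i, j] * B[j, k])

# With the identity f(), this recovers the ordinary matrix product.
assert torch.allclose(C, A @ B)
```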
I would recommend using einsum() to perform the scalar multiplications while "deferring" the additions, applying f() element-wise to the tensor resulting from einsum(), and then summing the results. Consider:
>>> import torch
>>> torch.__version__
'2.3.0'
>>> _ = torch.manual_seed (2024)
>>> a = torch.randn (3, 2)
>>> b = torch.randn (2, 3)
>>> a
tensor([[-0.0404, 1.7260],
[-0.8140, 1.3722],
[ 0.5060, -0.4823]])
>>> b
tensor([[-0.7853, 0.6681, -0.4439],
[ 0.1888, 0.5986, 0.6458]])
>>> a @ b
tensor([[ 0.3577, 1.0062, 1.1326],
[ 0.8983, 0.2775, 1.2476],
[-0.4884, 0.0493, -0.5361]])
>>> torch.einsum ('ij, jk -> ik', a, b)
tensor([[ 0.3577, 1.0062, 1.1326],
[ 0.8983, 0.2775, 1.2476],
[-0.4884, 0.0493, -0.5361]])
>>> torch.einsum ('ij, jk -> ikj', a, b)
tensor([[[ 0.0317, 0.3259],
[-0.0270, 1.0332],
[ 0.0179, 1.1147]],
[[ 0.6392, 0.2591],
[-0.5438, 0.8214],
[ 0.3614, 0.8862]],
[[-0.3973, -0.0911],
[ 0.3380, -0.2887],
[-0.2246, -0.3115]]])
>>> torch.einsum ('ij, jk -> ikj', a, b).sum (-1)
tensor([[ 0.3577, 1.0062, 1.1326],
[ 0.8983, 0.2775, 1.2476],
[-0.4884, 0.0493, -0.5361]])
>>> torch.einsum ('ij, jk -> ikj', a, b).sigmoid().sum (-1)
tensor([[1.0887, 1.2308, 1.2575],
[1.2190, 1.0618, 1.2975],
[0.8792, 1.0120, 0.8668]])
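The pattern above can be wrapped in a small helper (a minimal sketch; the name apply_then_sum is mine, and f can be any element-wise callable):

```python
import torch

def apply_then_sum(a, b, f):
    # einsum with output index 'ikj' keeps the j index un-summed,
    # so prods holds each individual scalar product a[i,j] * b[j,k].
    prods = torch.einsum('ij, jk -> ikj', a, b)
    # Apply f() to every term, then perform the deferred sum over j.
    return f(prods).sum(-1)

torch.manual_seed(2024)
a = torch.randn(3, 2)
b = torch.randn(2, 3)

# With the identity, this reproduces the ordinary matrix product.
assert torch.allclose(apply_then_sum(a, b, lambda t: t), a @ b)
```

For your case you would pass in your own f(), e.g. apply_then_sum(a, b, torch.sigmoid) to reproduce the last result above.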
Best.
K. Frank
Dear Frank:
Thank you very much for your detailed help; I have solved the problem in my project.
Best wishes from China.
Liu Jun