How to add an activation function to the intermediate result of torch.matmul

Hi Jun!

If I understand your question correctly, you wish to apply some function, f(),
to each scalar product of the form m * n in your matrix multiplication before
you add them together.

That is, for A of shape [3, 2] and B of shape [2, 3], your product matrix
A @ B will have shape [3, 3]. Each of its nine elements is a sum of two
terms, and you want to apply f() to each of those 3 * 3 * 2 = 18 terms
individually before adding them together.
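
In other words, instead of the usual

result[i, k] = a[i, 0] * b[0, k] + a[i, 1] * b[1, k]

you want

result[i, k] = f(a[i, 0] * b[0, k]) + f(a[i, 1] * b[1, k])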

I would recommend using einsum() to perform the scalar multiplications while
“deferring” the additions, then applying f() element-wise to the tensor that
einsum() returns, and finally summing over the deferred dimension.

Consider:

>>> import torch
>>> torch.__version__
'2.3.0'
>>> _ = torch.manual_seed(2024)
>>> a = torch.randn(3, 2)
>>> b = torch.randn(2, 3)
>>> a
tensor([[-0.0404,  1.7260],
        [-0.8140,  1.3722],
        [ 0.5060, -0.4823]])
>>> b
tensor([[-0.7853,  0.6681, -0.4439],
        [ 0.1888,  0.5986,  0.6458]])
>>> a @ b
tensor([[ 0.3577,  1.0062,  1.1326],
        [ 0.8983,  0.2775,  1.2476],
        [-0.4884,  0.0493, -0.5361]])
>>> torch.einsum('ij, jk -> ik', a, b)
tensor([[ 0.3577,  1.0062,  1.1326],
        [ 0.8983,  0.2775,  1.2476],
        [-0.4884,  0.0493, -0.5361]])
>>> torch.einsum('ij, jk -> ikj', a, b)
tensor([[[ 0.0317,  0.3259],
         [-0.0270,  1.0332],
         [ 0.0179,  1.1147]],

        [[ 0.6392,  0.2591],
         [-0.5438,  0.8214],
         [ 0.3614,  0.8862]],

        [[-0.3973, -0.0911],
         [ 0.3380, -0.2887],
         [-0.2246, -0.3115]]])
>>> torch.einsum('ij, jk -> ikj', a, b).sum(-1)
tensor([[ 0.3577,  1.0062,  1.1326],
        [ 0.8983,  0.2775,  1.2476],
        [-0.4884,  0.0493, -0.5361]])
>>> torch.einsum('ij, jk -> ikj', a, b).sigmoid().sum(-1)
tensor([[1.0887, 1.2308, 1.2575],
        [1.2190, 1.0618, 1.2975],
        [0.8792, 1.0120, 0.8668]])
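
If you need this in several places, you could wrap the three steps in a small
helper. This is just a minimal sketch (matmul_with_activation is an
illustrative name, and it assumes plain, unbatched 2-d inputs); applied with
torch.sigmoid to the same a and b as above, it reproduces the last result:

>>> def matmul_with_activation(a, b, f):
...     # all pairwise products a[i, j] * b[j, k], with the sum over j
...     # deferred -- the result has shape [i, k, j]
...     prods = torch.einsum('ij, jk -> ikj', a, b)
...     # apply the activation element-wise, then perform the deferred sum
...     return f(prods).sum(-1)
...
>>> matmul_with_activation(a, b, torch.sigmoid)
tensor([[1.0887, 1.2308, 1.2575],
        [1.2190, 1.0618, 1.2975],
        [0.8792, 1.0120, 0.8668]])

Note that this materializes the full [i, k, j] tensor of products, so it uses
more memory than a plain matmul, which only ever stores the [i, k] result.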

Best.

K. Frank


Dear Frank:
Thank you very much for your detailed help; I have solved the problem in my project.

Best wishes from China.
Liu Jun