Batch scalar multiplication

I have a matrix A with shape (N, 1) and a matrix B with shape (2, 2). I want to multiply B by each entry of A (a column vector), i.e. a scalar multiplication of B by each value, to get a tensor with shape (N, 2, 2) in which the i-th matrix along the first dimension is A[i] * B.
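To pin down the target result, here is a plain-loop reference version of what is being asked for. This is only a sketch: the names A, B and N follow the question, and the values are arbitrary examples.

```python
import torch

N = 3
A = torch.arange(N, dtype=torch.float32).unsqueeze(1)  # shape (N, 1)
B = torch.tensor([[1.0, 2.0], [3.0, 4.0]])              # shape (2, 2)

# Loop reference: slice i of the result is the scalar A[i, 0] times B.
out = torch.stack([A[i, 0] * B for i in range(N)])      # shape (N, 2, 2)
print(out.shape)  # torch.Size([3, 2, 2])
```

The loop is slow for large N, which is why a broadcasting version is preferable.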

I am really confused about how the broadcasting should work to get the desired result.

Thanks in advance!

This way:

```python
import torch

n = 3
a = torch.arange(n).unsqueeze(1) # [n, 1]
b = torch.ones(4).view(2, 2) # [2, 2]

a[:, :, None] * b # [n, 2, 2]
```
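Why this works: broadcasting compares shapes starting from the trailing dimension. `a[:, :, None]` has shape [n, 1, 1], and `b` with shape [2, 2] is treated as [1, 2, 2]; every size-1 dimension is stretched, giving [n, 2, 2]. The sketch below (arbitrary example values, float `a` for simplicity) checks this against a loop and shows two equivalent spellings, `unsqueeze(-1)` and `torch.einsum`. `torch.broadcast_shapes` assumes PyTorch 1.8 or newer.

```python
import torch

n = 3
a = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # [n, 1]
b = torch.tensor([[1.0, 2.0], [3.0, 4.0]])              # [2, 2]

# Shapes are aligned from the right: [n, 1, 1] vs [2, 2] -> [n, 2, 2]
print(torch.broadcast_shapes(a[:, :, None].shape, b.shape))  # torch.Size([3, 2, 2])

c1 = a[:, :, None] * b                           # indexing with None adds a trailing dim
c2 = a.unsqueeze(-1) * b                         # same thing via unsqueeze
c3 = torch.einsum("i,jk->ijk", a.squeeze(1), b)  # explicit outer product

# Loop reference: slice i is a[i, 0] * b
ref = torch.stack([a[i, 0] * b for i in range(n)])

print(torch.equal(c1, ref), torch.equal(c2, ref), torch.allclose(c3, ref))  # True True True
```

In the original answer `a` is an integer tensor and `b` is float, so the product is promoted to float automatically.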

Thanks a lot :smiley: It worked like a charm!
