I have a 2D tensor of matrices, i.e., I have matrices_tensor of size NxNx3x3 such that for index (i, j), matrices_tensor[i,j,…] yields a matrix of size 3x3.
I would like to do the following:
For a target of size Nx3x1, I would like to get an output of size Nx3 such that
output = torch.zeros((N, 3))
for i in range(N):
    for j in range(N):
        output[i, :] += matrices_tensor[i, j] @ target[j, :]
where @ is matrix multiplication operator.
I want to do this as a tensor operation, not a loop, in order to leverage GPU performance.
Eventually the matrix should be sparse, but for now I'm looking for a solution for the dense case.
It's not necessary to know the application, but for anyone curious: the purpose is optimizing a non-rigid transformation for mesh deformation, regularized by the smoothness of the rotated vertices.
torch.einsum() is your go-to tool when you want to sum multiplied slices ("contract indices") of tensors.
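To see what "contracting indices" means, here is a minimal sketch (plain stock PyTorch, nothing specific to your problem): the subscripts name each dimension, and any index that appears in the inputs but not in the output is summed over.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)

# 'ij,jk->ik': j appears in both inputs but not in the output,
# so it is summed over -- this is ordinary matrix multiplication.
c = torch.einsum('ij,jk->ik', a, b)

print(torch.allclose(c, a @ b))  # True
```

Your case is the same idea with more indices: in 'ijkl, jlm -> ik', both j and l are contracted.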
Here is an example script:
import torch

_ = torch.manual_seed(2021)

N = 5
m = 3
matrices_tensor = torch.randn(N, N, m, m)
target = torch.randn(N, m, 1)

# einsum version: j and l are contracted; the trailing singleton
# index m of target is summed out as well (it has size 1).
output_einsum = torch.einsum('ijkl, jlm -> ik', matrices_tensor, target)

# loop version for comparison
output = torch.zeros((N, m))
for i in range(N):
    for j in range(N):
        output[i, :] += (matrices_tensor[i, j] @ target[j, :]).squeeze()

print('output_einsum.shape =', output_einsum.shape)
print('output.shape =', output.shape)
print(torch.allclose(output_einsum, output))
And here is its output:
output_einsum.shape = torch.Size([5, 3])
output.shape = torch.Size([5, 3])
True
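If you prefer to avoid einsum, the same contraction can also be written with broadcast matmul followed by a sum over j. This is just a sketch of an equivalent formulation, assuming the same shapes as above:

```python
import torch

torch.manual_seed(2021)
N, m = 5, 3
matrices_tensor = torch.randn(N, N, m, m)
target = torch.randn(N, m, 1)

# matmul broadcasts target's batch dimension (N,) against (N, N),
# so result[i, j] = matrices_tensor[i, j] @ target[j], giving an
# (N, N, m, 1) tensor; summing over dim 1 (the j index) and
# squeezing the trailing singleton yields the (N, m) result.
output_matmul = torch.matmul(matrices_tensor, target).sum(dim=1).squeeze(-1)

output_einsum = torch.einsum('ijkl, jlm -> ik', matrices_tensor, target)
print(torch.allclose(output_matmul, output_einsum, atol=1e-5))
```

Which of the two is faster can depend on the backend and tensor sizes, so it is worth benchmarking both on your GPU.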
(Note that your target has a trailing singleton dimension, while your output does not, so I had to squeeze() the result of the matrix multiplication to make your code work.)
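Regarding your remark that the matrix should eventually be sparse: one option (a sketch of one possible approach, not the only way) is to flatten the N x N grid of 3x3 blocks into a single (3N) x (3N) matrix, store it in a sparse layout, and use torch.sparse.mm:

```python
import torch

torch.manual_seed(2021)
N, m = 5, 3
matrices_tensor = torch.randn(N, N, m, m)
target = torch.randn(N, m, 1)

# Rearrange so that big[m*i + k, m*j + l] == matrices_tensor[i, j, k, l],
# i.e. big is the (m*N) x (m*N) block matrix built from the 3x3 blocks.
big = matrices_tensor.permute(0, 2, 1, 3).reshape(N * m, N * m)
vec = target.reshape(N * m, 1)

# When most blocks are zero, a sparse layout pays off; torch.sparse.mm
# multiplies a sparse matrix with a dense one and returns a dense result.
big_sparse = big.to_sparse()
output_sparse = torch.sparse.mm(big_sparse, vec).reshape(N, m)

output_einsum = torch.einsum('ijkl, jlm -> ik', matrices_tensor, target)
print(torch.allclose(output_sparse, output_einsum, atol=1e-5))
```

Here the matrix is dense, so to_sparse() buys nothing; the point is only that once your real matrices_tensor has mostly zero blocks, the same reshape lets you build the sparse matrix directly from the nonzero blocks.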