How to create batch identity matrices?

I would like to create a batch of identity matrices to initialize a distributions.MultivariateNormal object. Is there a way to call torch.eye() as a batch operation?

If not, are there any alternatives you would suggest?


You can use repeat to repeat the identity tensor n times, but you will first need to convert the tensor to batch format:

  1. Reshape it to 1 x size x size.
  2. Call repeat with batch_size x 1 x 1.
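The two steps can be wrapped in a small helper. This is only a sketch: batch_eye is a hypothetical name, and the dtype/device arguments are just passed through to torch.eye.

```python
import torch

def batch_eye(batch_size: int, n: int, dtype=None, device=None) -> torch.Tensor:
    """Hypothetical helper: batch_size independent copies of the n x n identity."""
    # Step 1: build the identity and reshape it to 1 x n x n
    eye = torch.eye(n, dtype=dtype, device=device).reshape(1, n, n)
    # Step 2: repeat only along the batch dimension -> batch_size x n x n
    return eye.repeat(batch_size, 1, 1)

I = batch_eye(4, 2)
print(I.shape)  # torch.Size([4, 2, 2])
```

repeat also prepends missing leading dimensions by itself, so torch.eye(n).repeat(batch_size, 1, 1) gives the same result without the explicit reshape. Since repeat copies the data, each matrix in the batch can be modified independently.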

Here is an example:

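import torch

x = torch.eye(3)           # 3 x 3 identity
x = x.reshape((1, 3, 3))   # add a batch dimension: 1 x 3 x 3
y = x.repeat(5, 1, 1)      # repeat 5 times along the batch dimension: 5 x 3 x 3
print(x.shape)
print(y.shape)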

Output:
>>> print(x.shape)
torch.Size([1, 3, 3])
>>> print(y.shape)
torch.Size([5, 3, 3])
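If you only need to read the identities (for example to pass them as a covariance), expand is an alternative that returns a view instead of copying the data. A minimal sketch that also plugs the result into torch.distributions.MultivariateNormal, assuming a zero mean and a batch of 5 three-dimensional Gaussians:

```python
import torch
from torch.distributions import MultivariateNormal

batch_size, n = 5, 3

# expand returns a view: no copy is made, every batch entry shares the same storage
eye_batch = torch.eye(n).expand(batch_size, n, n)
print(eye_batch.shape)  # torch.Size([5, 3, 3])

# Batched identity as the covariance of 5 independent 3-D Gaussians (zero mean assumed)
loc = torch.zeros(batch_size, n)
mvn = MultivariateNormal(loc=loc, covariance_matrix=eye_batch)
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([5]) torch.Size([3])

# Writing in place to the expanded view raises an error, so clone it first if you need to modify it
eye_copy = eye_batch.clone()
eye_copy[0, 0, 0] = 2.0
```

MultivariateNormal also broadcasts loc against covariance_matrix, so passing a single torch.eye(n) together with a (5, n) loc already produces a batch shape of (5,).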
