How to create batch identity matrices?

You can use repeat to tile the identity matrix n times, but you need to convert the tensor to batch format first:

  1. Reshape the identity matrix to 1 x size x size.
  2. Call repeat with the repeat counts batchsize x 1 x 1.

Here is an example:

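import torch

x = torch.eye(3)             # 3 x 3 identity matrix
x = x.reshape((1, 3, 3))     # add a batch dimension: 1 x 3 x 3
y = x.repeat(5, 1, 1)        # tile 5 times along the batch dimension only
print(x.shape)
print(y.shape)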

Output:
>>> print(x.shape)
torch.Size([1, 3, 3])
>>> print(y.shape)
torch.Size([5, 3, 3])
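Two variations may also be useful. This is a sketch, and the names batch_size and n are just illustrative. Tensor.repeat prepends new leading dimensions when it is given more repeat counts than the tensor has dimensions, so the explicit reshape can be skipped. If you only need to read the batch, unsqueeze(0).expand(...) returns a view that copies no data.

```python
import torch

batch_size, n = 5, 3  # illustrative sizes

# repeat with 3 counts on a 2-D tensor prepends a new leading dim,
# so this yields an independent (copied) batch_size x n x n tensor.
batched_copy = torch.eye(n).repeat(batch_size, 1, 1)

# expand returns a view: no memory is copied, and every batch entry
# refers to the same underlying n x n storage.
batched_view = torch.eye(n).unsqueeze(0).expand(batch_size, n, n)

print(batched_copy.shape)
print(batched_view.shape)
print(torch.equal(batched_copy, batched_view))
```

The trade-off is memory versus mutability. Because every entry of the expanded view shares storage, writing to it in place is unsafe. Call .clone() on it, or use repeat, if you plan to modify individual matrices.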
