Pinverse of a batch of 2d matrices

torch.pinverse happily takes batches; the leading dimension is treated as the batch dimension:

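import torch

a = torch.randn(5, 6, 3)  # batch of five 6x3 matrices
b = torch.pinverse(a)     # batched call, result has shape (5, 3, 6)
# reference: pseudo-invert each matrix separately and stack the results
c = torch.stack([torch.pinverse(a_i) for a_i in a], 0)
print((b - c).abs().max())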

This gives a maximum discrepancy of about 5e-8, which is just float32 rounding noise.

As so often with linear algebra, the invariably awesome @vishwakftw implemented the batched version just in time for the 1.3 release.

As a general rule, the nicely named linear algebra functions do batches.
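
To illustrate that general rule, here is a small sketch. It assumes a newer PyTorch than the 1.3 mentioned above, one that ships the torch.linalg namespace (my choice for the example, not something from the original question). It calls two of those functions on the same kind of batch and checks the Moore-Penrose identity A @ A+ @ A = A for every matrix at once:

```python
import torch

torch.manual_seed(0)
a = torch.randn(5, 6, 3)  # batch of five 6x3 matrices

# Leading dimensions are treated as batch dimensions.
p = torch.linalg.pinv(a)                             # shape (5, 3, 6)
u, s, vh = torch.linalg.svd(a, full_matrices=False)  # s has shape (5, 3)

# The @ operator (matmul) is batched too, so this checks
# A @ A+ @ A == A for all five matrices in one go.
residual = (a @ p @ a - a).abs().max()
print(p.shape, s.shape, residual)
```

The residual should again be at the level of float32 rounding error.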

Best regards

Thomas
