Is it safe to do `torch.matmul(M, X, out=X)`?

I would like to compute X = M @ X without allocating a new matrix for the result of the right-hand side.
So I tried torch.matmul(M, X, out=X), and it seems to work.

I can even do torch.matmul(X, X, out=X) and the results seem to come out right:

[ins] In [53]: x = torch.randn(4,4)

[ins] In [54]: torch.matmul(x, x)
Out[54]:
tensor([[ 7.7674,  2.4361,  1.5354, -1.0358],
        [ 4.4382, -1.2255, -0.2265,  0.7528],
        [-4.6821,  0.8569, -4.4141, -0.9833],
        [-1.8084,  6.9762,  2.8231,  1.0764]])

[ins] In [55]: torch.matmul(x, x, out=x)
Out[55]:
tensor([[ 7.7674,  2.4361,  1.5354, -1.0358],
        [ 4.4382, -1.2255, -0.2265,  0.7528],
        [-4.6821,  0.8569, -4.4141, -0.9833],
        [-1.8084,  6.9762,  2.8231,  1.0764]])

My question: is it safe to rely on this behaviour, or do I need to allocate a temporary tensor for the result and then copy it into X?
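The temporary-tensor fallback mentioned in the question could look roughly like this. It is a minimal sketch, not a statement about what PyTorch does internally. `buf` is a hypothetical scratch tensor that is allocated once and reused, so no new result tensor is allocated per call. `M` and `X` are placeholder 4x4 tensors.

```python
import torch

M = torch.randn(4, 4)
X = torch.randn(4, 4)
expected = M @ X  # reference result, computed with a fresh allocation

# Hypothetical scratch buffer: allocate once, reuse on every iteration.
buf = torch.empty_like(X)

# The output does not alias either input, so the kernel never reads
# memory it has already overwritten.
torch.matmul(M, X, out=buf)

# Copy the result back into X's existing storage.
X.copy_(buf)
```

The cost is one extra copy of the result per call, in exchange for not depending on how the kernel handles aliased inputs and outputs.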

Generally this should be safe (if it wasn’t, the op should complain about aliased pointers).

Well, it doesn’t work for bmm, and there are no complaints in that case:

[ins] In [57]: x = torch.randn(2,2,2)

[ins] In [58]: torch.bmm(x,x)
Out[58]:
tensor([[[ 0.2954,  0.6422],
         [-0.3119,  1.9941]],

        [[ 0.5595,  0.0932],
         [-0.4605,  1.6915]]])

[ins] In [59]: torch.bmm(x,x,out=x)
Out[59]:
tensor([[[-0.3908,  0.0000],
         [ 0.0000,  0.0000]],

        [[-0.1640,  0.0000],
         [ 0.0000,  0.0000]]])

Are you saying this should have raised an error?

This should never be attempted, as it will most likely produce wrong results. We don’t check for it because it’s a very niche case, and a check that is both fast and accurate is hard to write, but it’s a very bad idea.
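To illustrate why such a check is hard to make both fast and accurate, here is a sketch of a deliberately conservative one. It is not something PyTorch does. `may_overlap` and `safe_bmm_` are hypothetical helpers, and the check assumes all tensors live on the same device. It only compares the byte ranges the tensors span, so it is cheap but can report an overlap for interleaved views that never touch the same element.

```python
import torch

def may_overlap(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Conservative test: True if the byte ranges spanned by a and b intersect.

    May return True for views that interleave without sharing elements
    (false positive), but never returns False for tensors that share memory.
    """
    def byte_range(t):
        if t.numel() == 0:
            return None
        # Offset (in elements) of the last element, plus one.
        span = sum((n - 1) * s for n, s in zip(t.shape, t.stride())) + 1
        start = t.data_ptr()
        return start, start + span * t.element_size()

    ra, rb = byte_range(a), byte_range(b)
    if ra is None or rb is None:
        return False
    return ra[0] < rb[1] and rb[0] < ra[1]

def safe_bmm_(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper: use out= only when out cannot alias an input."""
    if may_overlap(out, x) or may_overlap(out, y):
        out.copy_(torch.bmm(x, y))  # compute into a fresh tensor, then copy
    else:
        torch.bmm(x, y, out=out)
    return out

x = torch.randn(2, 2, 2)
expected = torch.bmm(x, x)
safe_bmm_(x, x, x)  # aliased call falls back to the temporary path
```

With this wrapper the aliased call from the bmm example above takes the fallback path, while a call with a separate output tensor still writes through out=.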

I’m not proposing to use this to square matrices; I’m just trying to figure out whether out= actually works at all, or whether it just allocates an internal tensor and copies it to out= when it’s done.

Or are you suggesting to just never use the out= parameter?

Another weird thing: out=temp also seems to work when temp has the wrong shape. This adds to my suspicion that maybe out= doesn’t actually work?
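One way to probe this from Python is to watch the output tensor's data pointer and shape around the call. The sketch below uses placeholder tensors `a`, `b`, `out`, and `temp`. Note what it can and cannot show: an unchanged `data_ptr()` means the result ended up in the tensor's existing storage, but it does not by itself rule out an internal temporary. The second half uses a zero-element tensor as the mismatched output; how a non-empty output of the wrong shape is handled is left open here.

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
expected = a @ b

# Case 1: out already has the right shape and dtype.
out = torch.empty(3, 5)
ptr_before = out.data_ptr()
torch.matmul(a, b, out=out)
same_storage = out.data_ptr() == ptr_before  # result is in the original storage

# Case 2: out has zero elements, so its shape does not match the result.
temp = torch.empty(0)
torch.matmul(a, b, out=temp)
resized_shape = tuple(temp.shape)  # temp was resized to fit the result

print(same_storage, resized_shape)
```

In the second case the tensor is resized to hold the result, which would explain why a wrongly shaped temp can appear to work.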