Re-writing torch.inverse for mobile conversion support

Hey Guys,

My model seems to function fine, but it uses torch.inverse, and when converting it for mobile that function wasn’t available. We checked that the model works fine without this function, so we have re-written it below. I think the functions are taking the data in and out of tensors, which is what is causing the errors and making the model return 0. I’m happy to post this function once complete for others to use; I just want to get it working. I’m newer to torch, but I think someone could quickly tell me how to keep this in the same tensor and work it through.

import torch

def transpose(m):
    # zip(*m) yields tuples; build a real list of lists so it can be indexed
    return [list(row) for row in zip(*m)]

def minor(m,i,j):
    return [row[:j] + row[j+1:] for row in (m[:i]+m[i+1:])]

def determinant(m):
    n = len(m)
    if n == 2:
        return m[0][0]*m[1][1]-m[0][1]*m[1][0]
    det = 0
    # Laplace (cofactor) expansion along the first row
    for c in range(n):
        det += ((-1)**c)*m[0][c]*determinant(minor(m,0,c))
    return det

def inverse(m):
    n = len(m)
    det = determinant(m)
    if n == 2:
        return [[m[1][1]/det, -1*m[0][1]/det],
                [-1*m[1][0]/det, m[0][0]/det]]
    cofactors = []
    for r in range(n):
        cofactorRow = []
        for c in range(n):
            mi = minor(m,r,c)
            cofactorRow.append(((-1)**(r+c)) * determinant(mi))
        cofactors.append(cofactorRow)
    # The transposed cofactor matrix is the adjugate; dividing by det gives the inverse
    cofactors = transpose(cofactors)
    for r in range(n):
        for c in range(n):
            cofactors[r][c] /= det
    return cofactors
    
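def tensor_inverse(m):
    arr = []
    sqz = torch.squeeze(m, dim=0)
    for mm in sqz:
        # minor() slices and concatenates Python lists, so convert each matrix first.
        # This leaves the tensor world: autograd and tracing cannot follow it.
        inv = inverse(mm.tolist())
        arr.append(inv)
    return torch.tensor(arr, dtype=torch.float32)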

tmpten = torch.rand(1,10,2,2)
print(tmpten)
print(torch.inverse(tmpten))
print(tensor_inverse(tmpten))

I redid the function using tensor operations, as shown below, but now it raises an error:

“RuntimeError: Output 1 of UnbindBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.”

#Redo Inverse using the tensor functions 
import torch

def tensordirect_inverse(m):
  #Determinant 
  print(m.shape)
  tmptensor1 = torch.mul(m[0][0], m[1][1])
  tmptensor2 = torch.mul(m[0][1], m[1][0])
  gettensorDeterminant = tmptensor1 - tmptensor2

  #Inverse Math
  tmpinverter1 = torch.divide(m[1][1], gettensorDeterminant)
  tmpinverter2 = torch.divide(torch.mul(torch.tensor([-1]), m[0][1]), gettensorDeterminant)
  tmpinverter3 = torch.divide(torch.mul(torch.tensor([-1]), m[1][0]), gettensorDeterminant)
  tmpinverter4 = torch.divide(m[0][0], gettensorDeterminant)
  return torch.tensor([[tmpinverter1, tmpinverter2], [tmpinverter3, tmpinverter4]])

def tensor_tensor_inverse(m):
  sqz = torch.squeeze(m,axis=0) 
  for i, x in enumerate(sqz):
    newtensor = tensordirect_inverse(x)
    sqz[i] = newtensor
    #print(inv)
  return sqz.expand(1, -1, -1,-1)

#Test Functions for proper tensor usage 



tmpten = torch.rand(1,10,2,2)
#tensor_tensor_inverse(tmpten)
#print(tmpten)
print(torch.inverse(tmpten))
print(tensor_tensor_inverse(tmpten))

Hi Matt!

I won’t comment on the correctness of your code, but let me note:

You are applying recursion to the Laplace expansion (also called
the minor or cofactor expansion) to calculate your determinant.
This algorithm is elegant, but numerically impractical for all but the
smallest matrices. For an n x n matrix, its computational cost (in
time) scales as n! – worse than exponential.

(Without using that “one weird trick,” the computational cost of things
like matrix multiplication, matrix inversion, and calculating determinants
is n^3.)

There is much discussion in the numerical literature about inverting
matrices and calculating determinants (and linear algebra, in general).
You should be using some form of Gaussian elimination.

(Also, it’s more efficient – and in my mind more straightforward – to
invert a matrix directly using Gaussian elimination, rather than by
forming the adjoint (adjugate) matrix. Lastly, if you want to solve a
set of linear equations for a single “right hand side,” you’re also better
off just applying Gaussian elimination directly to the problem, rather
than computing the full inverse matrix.)
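
To make the suggestion above concrete, here is a minimal sketch of direct inversion by Gauss-Jordan elimination with partial pivoting. It is written with plain Python lists so it has no dependencies; the function name gauss_jordan_inverse and the 1e-12 singularity threshold are assumptions made for illustration, not anything from the thread or from PyTorch. The cost is on the order of n^3 rather than n!.

```python
def gauss_jordan_inverse(a):
    """Invert a square matrix given as a list of lists (illustrative sketch)."""
    n = len(a)
    # Augment each row of a with the matching row of the identity matrix.
    aug = [[float(v) for v in row] + [1.0 if i == j else 0.0 for j in range(n)]
           for i, row in enumerate(a)]
    for col in range(n):
        # Partial pivoting: use the row with the largest entry in this column.
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        if abs(aug[pivot][col]) < 1e-12:  # assumed tolerance
            raise ValueError("matrix is singular to working precision")
        aug[col], aug[pivot] = aug[pivot], aug[col]
        # Scale the pivot row so that the pivot becomes 1.
        p = aug[col][col]
        aug[col] = [v / p for v in aug[col]]
        # Eliminate this column from every other row.
        for r in range(n):
            if r != col:
                f = aug[r][col]
                aug[r] = [rv - f * cv for rv, cv in zip(aug[r], aug[col])]
    # The right half of the augmented matrix now holds the inverse.
    return [row[n:] for row in aug]


inv = gauss_jordan_inverse([[4.0, 7.0], [2.0, 6.0]])
print(inv)
```

The same row operations can be expressed with tensor arithmetic if the result has to stay inside the graph; this list version is only meant to show the shape of the algorithm.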

Best.

K. Frank

Hey, thanks for your note. I am using small matrices, and I am not sure this solves my problem. Are you saying I should retrain the model with the inverse?

Hi Matt!

Then you should be fine (for some definition of small …). 5x5 matrices
should be fine, and even 10x10 or 15x15 matrices could work (but your
recursive algorithm has a cost for a 10x10 matrix that is as if you were
working with a 150x150 matrix, and a 15x15 matrix would effectively
correspond to the workload of a 10,000 x 10,000 matrix).
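
The equivalent sizes quoted above can be sanity-checked by asking which m gives m^3 = n!. This is a rough sketch that ignores constant factors in both operation counts (an assumption), and the helper name equivalent_cubic_size is made up for illustration.

```python
import math


def equivalent_cubic_size(n):
    """Matrix size m whose m**3 cost matches the n! cost of cofactor expansion."""
    return math.factorial(n) ** (1.0 / 3.0)


for n in (5, 10, 15):
    print(n, round(equivalent_cubic_size(n)))
```

For n = 10 this comes out a little above 150, and for n = 15 a little above 10,000, in line with the figures in the reply.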

Again, I’m not commenting on the correctness of your code, but the
algorithm you have chosen is logically legitimate, albeit highly inefficient.

What specifically is your problem?

I’m not sure of your use case.

If your matrix in question is a trainable parameter, and only its inverse
is used in the forward pass, then, yes, it would be more straightforward
and cheaper to work directly with the inverse matrix as the parameter.

In principle, you shouldn’t even have to retrain your network. Just keep
all of your trained parameters as is, except for replacing the matrix by
its inverse, and modify the forward function to use the inverse matrix
directly (rather than inverting the original matrix as part of the forward
pass).
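
A minimal sketch of that substitution, assuming a toy layer whose only use of the matrix is through its inverse. The module names WithInverse and WithStoredInverse and the attribute names are hypothetical and not taken from the model in this thread.

```python
import torch


class WithInverse(torch.nn.Module):
    """Hypothetical original layer: inverts its parameter on every forward pass."""

    def __init__(self, dim):
        super().__init__()
        # Identity plus a small perturbation, so the matrix is safely invertible.
        self.mat = torch.nn.Parameter(torch.eye(dim) + 0.1 * torch.rand(dim, dim))

    def forward(self, x):
        return x @ torch.linalg.inv(self.mat)


class WithStoredInverse(torch.nn.Module):
    """Hypothetical export variant: stores the inverse, so forward needs no inversion."""

    def __init__(self, inv_mat):
        super().__init__()
        self.inv_mat = torch.nn.Parameter(inv_mat)

    def forward(self, x):
        return x @ self.inv_mat


torch.manual_seed(0)
trained = WithInverse(3)
with torch.no_grad():
    # Invert once, offline, where torch.linalg.inv is available.
    export_ready = WithStoredInverse(torch.linalg.inv(trained.mat))

x = torch.rand(4, 3)
print(torch.allclose(trained(x), export_ready(x), atol=1e-5))
```

Only the second module would be handed to the converter, so the unsupported op never appears in the exported graph.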

Of course, if you use both the original matrix and its inverse directly in
the forward pass, yes, you will need to invert the matrix as part of the
forward pass, and therefore port torch.inverse() (or some equivalent)
to your mobile platform.

Best.

K. Frank

Hi,

Yes, the original matrix and the inverse are both used. What I am currently trying to figure out is how to write a Python function that can be used in the Core ML conversion, since the PyTorch inverse function is currently not supported. But I am getting the error below and am trying to figure out how to make the function work, as the inverse function I wrote seems to give the same result.

“RuntimeError: Output 1 of UnbindBackward0 is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.”

Hi Matt!

You need to find where you are performing an inplace tensor operation
(and eliminate it, perhaps by working with a copy of the tensor in question).

Most likely, sqz[i] = newtensor is the offending line. Assignment to a
slice of a tensor is an inplace modification. Perhaps you could construct
a new empty (or zero) tensor of the same shape as sqz and assign to
slices of that new tensor.
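
One way to avoid the in-place write altogether, sketched here for the 2x2 case only: build the result out of place with torch.stack, using the closed-form 2x2 inverse. The function name inverse_2x2 is an assumption for illustration. The result is also assembled from the elements themselves rather than passed through torch.tensor(...), since torch.tensor copies values and drops the autograd history. Whether a given mobile converter supports every op used here is something to verify separately.

```python
import torch


def inverse_2x2(m):
    """Closed-form inverse for a tensor of shape (..., 2, 2), with no in-place writes."""
    a, b = m[..., 0, 0], m[..., 0, 1]
    c, d = m[..., 1, 0], m[..., 1, 1]
    det = a * d - b * c
    # torch.stack allocates new tensors, so no view of m is ever modified.
    top = torch.stack([d, -b], dim=-1)
    bottom = torch.stack([-c, a], dim=-1)
    adjugate = torch.stack([top, bottom], dim=-2)
    # Broadcast the determinant over the trailing 2x2 dimensions.
    return adjugate / det[..., None, None]


torch.manual_seed(0)
# Adding 2*I keeps every determinant well away from zero in this demo.
m = torch.rand(1, 10, 2, 2) + 2.0 * torch.eye(2)
m.requires_grad_(True)

inv = inverse_2x2(m)
inv.sum().backward()  # gradients flow, no view or in-place error
print(torch.allclose(inv, torch.linalg.inv(m), atol=1e-5))
```

Because the indexing uses a leading ellipsis, the whole batch is handled at once and the Python loop over sub-tensors disappears.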

Best.

K. Frank

I tried cloning this tensor, but it seems even that gives the error.

5 #Determinant
6 print(m.shape)
----> 7 newtensor = m.clone()
8 tmptensor1 = newtensor[0][0] * newtensor[1][1]; #torch.mul(m[0:0],m[1][1]);
9 tmptensor2 = torch.mul(m[0][1].clone(), m[1][0].clone());

I redid the function as follows. The tensor is now copied at the source, but it seems that iterating with enumerate causes this issue. Is there another way to do that?

#Redo Inverse using the tensor functions 
import torch

def tensordirect_inverse(m):
  #Determinant 
  print(m.shape)
  newtensor = m.clone()
  tmptensor1 = newtensor[0][0] * newtensor[1][1]  #torch.mul(m[0:0],m[1][1])
  tmptensor2 = torch.mul(m[0][1].clone(), m[1][0].clone())
  gettensorDeterminant = tmptensor1 - tmptensor2

  #Inverse Math
  tmpinverter1 = torch.divide(m[1][1], gettensorDeterminant)
  tmpinverter2 = torch.divide(torch.mul(torch.tensor([-1]), m[0][1]), gettensorDeterminant)
  tmpinverter3 = torch.divide(torch.mul(torch.tensor([-1]), m[1][0]), gettensorDeterminant)
  tmpinverter4 = torch.divide(m[0][0], gettensorDeterminant)
  return torch.tensor([[tmpinverter1, tmpinverter2], [tmpinverter3, tmpinverter4]])

def tensor_tensor_inverse(m):
  tmptensor = m.clone()
  sqz = torch.squeeze(tmptensor,axis=0) 
  newsqz = sqz.clone()
  for i, x in enumerate(newsqz):
    newtensor = tensordirect_inverse(x.clone())
    newsqz[i] = newtensor
    #print(inv)
  return newsqz.expand(1, -1, -1,-1)

#Test Functions for proper tensor usage 



tmpten = torch.rand(1,10,2,2)
#tensor_tensor_inverse(tmpten)
#print(tmpten)
print(torch.inverse(tmpten))
print(tensor_tensor_inverse(tmpten))

OK, I think I figured it out: I should just index into the original cloned tensor rather than pass along the sub-tensor yielded by the enumerate call.