How to export to TensorFlow when the model contains torch.atan2()?


(Scott Hawley) #1

I want to export my model to TensorFlow, which, as I understand it, means I need to export to ONNX first. (Is that right?)

My model contains a torch.atan2() operation, which is not on the list of supported operators for ONNX export. Indeed, the ONNX operator list itself includes atan but not atan2, even though TensorFlow has atan2.
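For reference, atan2 can be written piecewise in terms of atan (ignoring signed zeros), which is what makes an atan-plus-where rewrite possible. The origin case follows the common convention of returning 0, which is also what torch.atan2 returns:

```latex
\operatorname{atan2}(y, x) =
\begin{cases}
\arctan(y/x)         & x > 0 \\
\arctan(y/x) + \pi   & x < 0,\ y \ge 0 \\
\arctan(y/x) - \pi   & x < 0,\ y < 0 \\
+\pi/2               & x = 0,\ y > 0 \\
-\pi/2               & x = 0,\ y < 0 \\
0                    & x = 0,\ y = 0
\end{cases}
```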

So, if this were NumPy, I could just write my own little “my_atan2()” function that calls atan() and then uses some kind of where() to decide how many factors of pi to add… maybe something like this?

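import math
import torch

def my_atan2(y, x):
    ans = torch.atan(y / x)   # in (-pi/2, pi/2); already correct for x > 0
    ans = torch.where((y >= 0) & (x < 0), ans + math.pi, ans)   # upper left quadrant (and negative x-axis)
    ans = torch.where((y < 0) & (x < 0), ans - math.pi, ans)    # lower left quadrant: subtract pi to stay in (-pi, pi]
    # upper right quadrant and lower right quadrant, do nothing
    # (x == 0 falls out of atan(+-inf) = +-pi/2, but x == y == 0 gives nan from 0/0)
    return ans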

…But I’m guessing that won’t satisfy Autograd. And looking at the backward code for torch.atan2()… yeah, I don’t understand what’s going on there.
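One way to settle both the value and the Autograd question empirically is to compare a hand-rolled version against torch.atan2 on random inputs, including gradients. The sketch below assumes only torch.atan, torch.where, division, and elementwise comparisons are needed (ONNX has Atan, Div, Where, Equal, Less, Greater and And operators), though I have not verified export for every opset. The name `atan2_via_atan` is hypothetical. Autograd handles torch.where fine, since gradients flow only through the selected branch; the one known gap in this sketch is that gradients on the line x == 0 come out as 0 rather than the true value.

```python
import math
import torch


def atan2_via_atan(y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical atan2 replacement built only from atan, where and comparisons."""
    # Replace x == 0 by 1 before dividing so the discarded branch stays finite;
    # otherwise its backward pass could produce nan even though torch.where
    # throws its value away.
    x_safe = torch.where(x == 0, torch.ones_like(x), x)
    base = torch.atan(y / x_safe)  # in (-pi/2, pi/2); already correct for x > 0

    # Left half-plane: add pi when y >= 0, subtract pi when y < 0,
    # so the result stays in (-pi, pi] like torch.atan2.
    left = x < 0
    ans = torch.where(left & (y >= 0), base + math.pi, base)
    ans = torch.where(left & (y < 0), base - math.pi, ans)

    # On the y-axis: +pi/2 above the origin, -pi/2 below, 0 at the origin.
    # These entries are constants, so their gradient is 0 (not the exact value).
    half_pi = torch.full_like(y, math.pi / 2)
    on_axis = torch.where(y > 0, half_pi, torch.where(y < 0, -half_pi, torch.zeros_like(y)))
    return torch.where(x == 0, on_axis, ans)


# Compare values and gradients with torch.atan2 on random points
# (hitting x == 0 exactly is vanishingly unlikely here).
torch.manual_seed(0)
y = torch.randn(1000, dtype=torch.float64, requires_grad=True)
x = torch.randn(1000, dtype=torch.float64, requires_grad=True)

ours = atan2_via_atan(y, x)
ref = torch.atan2(y, x)
gy, gx = torch.autograd.grad(ours.sum(), (y, x))
ry, rx = torch.autograd.grad(ref.sum(), (y, x))

print("max value error:", (ours - ref).abs().max().item())
print("max grad error: ", max((gy - ry).abs().max().item(), (gx - rx).abs().max().item()))
```

If both printed errors are at floating-point rounding level, the function can stand in for torch.atan2 inside forward() before export.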

Any suggestions?


(Scott Hawley) #2

Whoa, actually that seems to work! But now other functions also throw “undefined” errors; for example, torch.flip() is not supported.

And my new DIY implementation of torch.flip() relies on torch.arange(), which throws an ONNX export error about “index” not existing.
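One workaround that may be worth trying, assuming the export error comes from calling torch.arange() inside forward(): build the reversed index once in __init__, register it as a buffer, and flip with index_select(), so the exported graph only sees a constant index tensor (index_select normally maps to an ONNX Gather). I have not confirmed that this avoids the specific “index” error. The class name `FlipAlongDim` is hypothetical, and the sketch requires the size of the flipped dimension to be known when the module is built.

```python
import torch
from torch import nn


class FlipAlongDim(nn.Module):
    """Hypothetical stand-in for torch.flip(x, [dim]) on a dimension of fixed size."""

    def __init__(self, size: int, dim: int = -1):
        super().__init__()
        self.dim = dim
        # Reversed indices [size-1, ..., 1, 0], computed once rather than in forward().
        # As a buffer it follows the module through .to(device) and is saved in state_dict.
        self.register_buffer("rev_idx", torch.arange(size - 1, -1, -1, dtype=torch.long))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gather the slices along self.dim in reversed order.
        return x.index_select(self.dim, self.rev_idx)


flip = FlipAlongDim(5)  # flips the last dimension, which must have size 5
x = torch.arange(10.0).reshape(2, 5)
print(flip(x))  # same result as torch.flip(x, [-1])
```

Because index_select is differentiable, gradients pass through the reversal just as they do for torch.flip.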

…bah. This is no fun. Is ONNX still the only way to get to TensorFlow from PyTorch?

(This is for a web app. We may just go with Bokeh Server.)