Best way to encode complex values?

Hi all,

I am a physicist and I use deep learning on physical systems, where the physics is usually linear/simple when expressed with complex values. That’s why I am trying to use complex-valued ANNs, and I already use a custom set of functions/layers to implement complex layers.

So far, all my functions take two arguments: one tensor for the real part and one for the imaginary part. I am now wondering whether this is the best way to go.

I recently discovered that PyTorch does have one kind of complex operation, as it supports complex FFTs (which is awesome, by the way):
https://pytorch.org/docs/master/torch.html?highlight=fft#torch.fft

This takes a tensor with an additional last dimension of size 2.
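
For reference, a minimal sketch of that interface (this is the legacy torch.fft signature discussed here, removed in later PyTorch releases; the shapes are my own example):

import torch

# real/imaginary parts stacked along a trailing dimension of size 2
x = torch.randn(8, 64, 2)               # a batch of 8 signals of length 64
spectrum = torch.fft(x, signal_ndim=1)  # legacy complex-to-complex FFT
print(spectrum.shape)                   # torch.Size([8, 64, 2])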

My question is, is it better to keep two tensors as arguments, or one with an additional dimension?

The advantage of the two-argument scheme is that, since most complex functions simply apply independent operations to the real and imaginary parts, you can pass the two tensors directly to the built-in (real) PyTorch functions. The simplest example is the C-ReLU:

def complex_relu(input_r,input_i):
    return relu(input_r), relu(input_i)

The problem is that the syntax then becomes unwieldy (especially for cost functions, which take two complex vectors and therefore four tensors).
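
For concreteness, here is what a complex MSE could look like in the two-tensor scheme (the function is just an illustration of the four-tensor signature, not my actual code):

def complex_mse(pred_r, pred_i, target_r, target_i):
    # mean of |pred - target|^2: two complex vectors become four tensors
    return ((pred_r - target_r) ** 2 + (pred_i - target_i) ** 2).mean()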

If I want to use the single-tensor argument, as in torch.fft, I then have to separate the real and imaginary parts:

def complex_relu(complex_input):
    # slice out each part, apply relu, and stack them back together
    return torch.stack((relu(complex_input[..., 0]), relu(complex_input[..., 1])), dim=-1)

Is it really optimal? It seems that it would make copies of the sliced tensors; is it memory efficient? Is there a better way to do it?

As far as I understand the bug report tracking complex support, the current favourite implementations are the external modules pytorch-cpu-strided-complex and its CUDA equivalent (note that you need to import the cpp submodule). To me, they seem to have more than a few rough edges at the moment, but I’d imagine that if you’re keen on helping out, no one will complain.
One thing I didn’t find is how to get zero-copy access to the real and imaginary parts. I would imagine that this should be possible (because that’s how they’re laid out in memory, after all), but .real and .imag stay in the complex dtypes rather than returning the real ones.

Best regards

Thomas

Hi Tom, and thank you for your quick reply,

Thanks for the links, but frankly, this seems to be out of my league.
Moreover, as far as I understand, pytorch-cpu-strided-complex would help me get complex-typed tensors, but not so much complex operations. Am I wrong about that?
And without complex matrix multiplication, for instance, it is no use to me.

I am fine for now with my rough Python implementation, although I am perfectly aware that it is far from optimal.
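
For instance, the complex matrix multiplication I mentioned reduces to four real matmuls in the two-tensor scheme (a rough sketch of the kind of thing I use):

def complex_matmul(A_r, A_i, B_r, B_i):
    # (A_r + i*A_i) @ (B_r + i*B_i) = (A_r@B_r - A_i@B_i) + i*(A_r@B_i + A_i@B_r)
    return A_r @ B_r - A_i @ B_i, A_r @ B_i + A_i @ B_r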

My question is then: of the two methods I proposed, which is best? Is the slicing in the second forward implementation reasonable, or does it create useless memory usage?

For the functions above, I personally prefer to use explicit functions instead of advanced indexing, to be sure no copy happens. In your case, you can replace complex_input[...,0] with complex_input.select(-1, 0). .select() never copies and is really cheap (cheaper than advanced indexing), so that should be optimal.
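
A quick way to convince yourself that .select() returns a view over the same storage (a small sketch):

import torch

x = torch.randn(4, 2)
v = x.select(-1, 0)  # view of the first slice, no copy
v.fill_(0.0)         # writing through the view...
print(x[..., 0])     # ...shows up in the original tensor: all zeros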


Sounds like a great solution to my issue with slicing, I will test it!

Well, I did some tests, and select() does not seem so cheap. Here is what I tested:

from torch.nn.functional import relu
import torch

def complex_relu(input_r, input_i):
    # two-argument scheme: independent relu on each part
    return relu(input_r), relu(input_i)

def complex_relu2(complex_input):
    # single-tensor scheme: advanced indexing, then stack
    return torch.stack((relu(complex_input[..., 0]), relu(complex_input[..., 1])), dim=-1)


def complex_relu3(complex_input):
    # in-place variant (note: mutates its input)
    complex_input[..., 0] = relu(complex_input.select(-1, 0))
    complex_input[..., 1] = relu(complex_input.select(-1, 1))
    return complex_input

def complex_relu4(complex_input):
    # .select() variant (the second slice was left as advanced indexing)
    return torch.stack((relu(complex_input.select(-1, 0)), relu(complex_input[..., 1])), dim=-1)

device = torch.device("cuda:0")

n = 1000
X_r = torch.randn(n, n).to(device)
X_i = torch.randn(n, n).to(device)
X = torch.stack((X_r, X_i), dim=-1)  # single-tensor layout: trailing dim of size 2
%timeit complex_relu(X_r,X_i)

33.1 µs ± 6.04 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit complex_relu2(X)

148 µs ± 32.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit complex_relu3(X)

130 µs ± 12.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit complex_relu4(X)

148 µs ± 33.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

So, any thoughts on how to make the single-tensor scheme as fast as the two separate arguments?

Obviously, this one is a good option for relu:

def complex_relu5(complex_input):
    # clamp acts elementwise, so no slicing or stacking is needed at all
    return torch.clamp(complex_input, min=0)

31.6 µs ± 1.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Still, for other types of functions, I would need to slice and stack the real and imaginary parts.
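
For example, a pointwise complex multiplication in the stacked layout still needs a select/stack pair (an illustrative sketch, not a tuned implementation):

def complex_mul(a, b):
    # elementwise (a_r + i*a_i) * (b_r + i*b_i) for tensors with a trailing dim of size 2
    a_r, a_i = a.select(-1, 0), a.select(-1, 1)
    b_r, b_i = b.select(-1, 0), b.select(-1, 1)
    return torch.stack((a_r * b_r - a_i * b_i,
                        a_r * b_i + a_i * b_r), dim=-1)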

Hi,

I would guess that most of the runtime there comes from the .stack() call, not from how you do the slicing.

Also, be careful if you’re doing timings on CUDA, because the API is asynchronous. You need to insert torch.cuda.synchronize() to make sure you measure actual runtimes.
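
For example, with CUDA events (a minimal sketch of one way to do it):

import torch

x = torch.randn(1000, 1000, 2, device="cuda")
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()        # drain any pending work first
start.record()
y = torch.clamp(x, min=0)
end.record()
torch.cuda.synchronize()        # wait for the timed kernel to finish
print(start.elapsed_time(end), "ms")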

Thanks for the tip, I did not know.

Stacking slows down the code, but so does slicing:

A = torch.randn(20000).to(device)
%timeit AA = A           # plain aliasing, no work done
%timeit AA = A[:-1]      # slicing: creates a view
%timeit A[:-1] = A[:-1]  # sliced in-place copy

18.4 ns ± 3.15 ns per loop (mean ± std. dev. of 7 runs, 10000000 loops each)
1.7 µs ± 3.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
14.6 µs ± 2.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Is that the best I can do?

Moreover, I discovered that there are now complex-value-related functions in PyTorch!
https://pytorch.org/docs/stable/torch.html

We find torch.angle(), torch.real() and torch.imag().

So I upgraded PyTorch to the latest version (1.5) on my compute server and ran the usage example for the torch.angle() function from the official website:

torch.angle(torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]))*180/3.14159

and got:

RuntimeError: Could not infer dtype of complex

Instead of the result from the PyTorch documentation:

tensor([ 135.,  135,  -45])

If there is still no official complex tensor support, how can one use those functions?

We are starting to do so, but I’m afraid not many functions are implemented yet.

Is that the best I can do?

1 microsecond is actually quite good by PyTorch standards. This is basically just the overhead of the framework; you won’t find any PyTorch function cheaper than that.

Good to know, thanks!

About the complex tensors: since the example is in the official documentation, can anybody tell me how to reproduce its output?

You’ll need to install the extension mentioned earlier, https://gitlab.com/pytorch-complex/pytorch-cpu-strided-complex, first, before the complex operations work. We’re planning to move them into PyTorch, maybe even by 1.5, so you don’t have to do a complicated installation step first.

I did (though I had to modify the .cpp, which seems to target an older version of the API), but even after a successful compilation and installation, it fails the tests. Never mind; if it makes it into the next release, I am good with it. Thanks!

Will it be only CPU though?

As of 1.7.0, complex numbers largely work as expected, but a substantial number of functions don’t yet have CUDA acceleration or backward support.
However, if you’re happy to take your own derivatives and not use the GPU acceleration, it works great.
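
For example, on a recent build (a quick sketch of the native complex API):

import torch

z = torch.tensor([1.0 + 2.0j, 3.0 - 1.0j], dtype=torch.complex64)
print(z.real, z.imag)          # zero-copy real/imaginary views
print(torch.view_as_real(z))   # (..., 2) real view over the same storage
print(torch.angle(z))          # now works natively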

Also, when you find an unimplemented feature you need, check the issues on GitHub first, and feel free to create an issue if there isn’t one already.

For a research project in physics I am working on, I found that if you want to use sequential layers, the best way to define a custom activation function or layer in your model is to create a Module subclass. This lets you place a custom function directly into the nn.Sequential() stack. I could not easily find this documented for sequential layers (let alone for a ModuleList), because these layers feed their outputs directly into the next function; I only saw examples that manually apply a linear layer to an input, assign the result to a variable, then feed that variable into the next function (e.g. x = layer1(x), x = activation1(x), x = layer2(x), and so on).

Create a Module subclass for your split complex ReLU activation function:

import torch
import torch.nn as nn

class complex_ReLU(nn.Module):
    def forward(self, x):
        return nn.ReLU()(x.real) + 1.j * nn.ReLU()(x.imag)  # real ReLU on each part

Define Model

class find_matrices(nn.Module):
    def __init__(self, n_d, n_a, n_c):
        super(find_matrices, self).__init__()

        # n_a independent networks, each acting on a flattened (n_d x n_d) matrix
        A = [ nn.Sequential( nn.Linear(in_features = n_d*n_d, out_features = n_d*n_d, bias = True, dtype=torch.complex128),
                             complex_ReLU(),
                             nn.Linear(in_features = n_d*n_d, out_features = n_d*n_d, bias = True, dtype=torch.complex128),
                             complex_ReLU(),
                             nn.Linear(in_features = n_d*n_d, out_features = n_d*n_d, bias = True, dtype=torch.complex128) ) for _ in range(n_a) ]

        self.As = nn.ModuleList(A)

        # n_c independent networks, each acting on a length-n_a vector
        B = [ nn.Sequential( nn.Linear(in_features = n_a, out_features = n_a, bias = True, dtype=torch.complex128),
                             complex_ReLU(),
                             nn.Linear(in_features = n_a, out_features = n_a, bias = True, dtype=torch.complex128),
                             complex_ReLU(),
                             nn.Linear(in_features = n_a, out_features = n_a, bias = True, dtype=torch.complex128) ) for _ in range(n_c) ]

        self.Bs = nn.ModuleList(B)

        self.n_a = n_a
        self.n_d = n_d
        self.n_c = n_c

    def forward(self, x, c):
        a = []
        for i in range(self.n_a):
            a.append( ( self.As[i](x[i].flatten()) ).reshape(self.n_d, self.n_d) )

        # complex dtype so the in-place assignment below works
        b = torch.empty((self.n_c, self.n_a), dtype=torch.complex128)
        for i in range(self.n_c):
            b[i, :] = ( self.Bs[i](c[i].flatten()) ).reshape(1, self.n_a)

        return a, b

model = find_matrices(n_d, n_a, n_c).to(device)
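
A quick smoke test with made-up dimensions (the concrete values of n_d, n_a, n_c and the inputs here are purely illustrative):

import torch

n_d, n_a, n_c = 3, 4, 2
device = torch.device("cpu")
model = find_matrices(n_d, n_a, n_c).to(device)

x = torch.randn(n_a, n_d, n_d, dtype=torch.complex128)  # one (n_d, n_d) matrix per A-network
c = torch.randn(n_c, n_a, dtype=torch.complex128)       # one length-n_a vector per B-network
a, b = model(x, c)
print(len(a), b.shape)  # n_a matrices and a (n_c, n_a) tensor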

Hope this helps!