Power mean in pytorch

Hi,
I am trying to implement the power mean in PyTorch following this paper.
It is pretty straightforward to do in NumPy:

import numpy as np

def gen_mean(vals, p):
    p = float(p)
    return np.power(
        np.mean(
            np.power(
                np.array(vals, dtype=complex),
                p),
            axis=0),
        1 / p
    )

Or in TensorFlow:

import tensorflow as tf

def p_mean(values, p, n_toks):
    n_toks = tf.cast(tf.maximum(tf.constant(1.0), n_toks), tf.complex64)
    p_tf = tf.constant(float(p), dtype=tf.complex64)
    values = tf.cast(values, dtype=tf.complex64)
    res = tf.pow(
        tf.reduce_sum(
            tf.pow(values, p_tf),
            axis=1,
            keepdims=False
        ) / n_toks,
        1.0 / p_tf
    )
    return tf.math.real(res)  # tf.real in TF 1.x

Source
However, since PyTorch does not support complex numbers, this seems far from trivial.
For example, the geometric mean of negative numbers does not seem possible in PyTorch.
Am I missing something?
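Side note for readers on newer PyTorch versions: later PyTorch releases added complex dtypes such as torch.complex128, so gen_mean can be ported almost line by line. This is a minimal sketch assuming such a version; power_mean_complex is a hypothetical name, not a PyTorch function. Like NumPy, it uses the principal branch, and the sign of a near-zero imaginary part (signed zeros included) decides which of the two complex roots comes out, so only the magnitude and the real part are fully pinned down.

```python
import torch


def power_mean_complex(x, p, dim=0):
    # Promote the real input to a complex dtype so negative bases are allowed.
    z = x.to(torch.complex128)
    # Mean of z**p along `dim` (complex pow uses the principal branch).
    m = torch.pow(z, p).mean(dim)
    # Principal 1/p-th root; the result may have a nonzero imaginary part.
    return torch.pow(m, 1.0 / p)


# 3rd-power mean of -9 and -3: magnitude 378**(1/3), real part half of that.
print(power_mean_complex(torch.tensor([-9.0, -3.0]), 3))
```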

You could mock a complex number using two floats (one for magnitude and one for phase). That way the pow operation becomes easy, but the mean operation is not that simple. Splitting the number into its real and imaginary parts would make the mean easy, but the pow would become hard.

What I would recommend is to use both representations (the most suitable one for each operation) and convert between them. This may not be very efficient, but it should work.
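A minimal sketch of that idea, assuming each complex value is stored as two float tensors; to_polar, to_cartesian and power_mean_two_floats are hypothetical helper names. The pow steps run in polar form (raise the magnitude, scale the phase) and the mean runs in Cartesian form (average real and imaginary parts separately). As with real complex numbers, a tiny rounding error in the imaginary part can decide which branch atan2 picks for the final root.

```python
import torch


def to_polar(re, im):
    # Cartesian -> polar; atan2 returns the principal angle in [-pi, pi].
    return torch.hypot(re, im), torch.atan2(im, re)


def to_cartesian(mag, phase):
    # Polar -> Cartesian.
    return mag * torch.cos(phase), mag * torch.sin(phase)


def power_mean_two_floats(x, p, dim=0):
    # Real input: imaginary part is zero, phase is 0 (x >= 0) or pi (x < 0).
    mag, phase = to_polar(x, torch.zeros_like(x))
    # pow is easy in polar form.
    re_p, im_p = to_cartesian(mag ** p, phase * p)
    # mean is easy in Cartesian form.
    re_m, im_m = re_p.mean(dim), im_p.mean(dim)
    # Back to polar for the final 1/p power, then return (real, imag).
    mag_m, phase_m = to_polar(re_m, im_m)
    return to_cartesian(mag_m ** (1.0 / p), phase_m / p)


re, im = power_mean_two_floats(torch.tensor([-9.0, -3.0], dtype=torch.float64), 3)
print(re, im)
```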

Given that your input only has two phases (0 and π), if p is fixed you can precompute the corresponding phase factor and just multiply it according to the sign. Then you multiply by the p-norm divided by the vector size to the power 1/p (or absorb that factor into the phase).
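The reasoning behind this can be written out as follows, assuming real inputs written as x_i = |x_i| e^{iθ_i} and, for the last step, that all entries share the same phase θ (mixed signs still need the general complex sum). Here ω is a 1/p-th root of e^{ipθ}: ω = e^{iθ} gives the real root (sign times the power mean of |x|), while the principal root (e.g. e^{iπ/3} for p = 3, θ = π) is what gen_mean returns.

```latex
x_i = |x_i|\,e^{i\theta_i},\quad \theta_i \in \{0, \pi\}
\quad\Longrightarrow\quad x_i^p = |x_i|^p\,e^{ip\theta_i}

\text{If all } \theta_i = \theta:\quad
\frac{1}{n}\sum_{i=1}^{n} x_i^p = e^{ip\theta}\,\frac{\|x\|_p^p}{n}
\quad\Longrightarrow\quad
M_p(x) = \omega\,\frac{\|x\|_p}{n^{1/p}},\qquad \omega^p = e^{ip\theta}
```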

Best regards

Thomas

Edited P.S.: Looking at the paper authors’ implementation, they use min and max (the limits as p approaches -∞ and +∞), the mean, and the 3rd-power mean, so it seems simple enough and involves no complex numbers (if you take the power of 1/3 to be the inverse of the power of 3).
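The four statistics described in the P.S. can be sketched as below. This is not the authors' code: real_cbrt and power_mean_features are hypothetical names, the sequence is assumed to lie along dim 0 with features in the last dimension, and the cube root is taken as the real inverse of x**3, i.e. sign(t) * |t|**(1/3), which is the reading the P.S. suggests.

```python
import torch


def real_cbrt(t):
    # Real cube root: the inverse of t**3 on the reals, so no complex numbers.
    return t.sign() * t.abs().pow(1.0 / 3)


def power_mean_features(x, dim=0):
    # min (p -> -inf), max (p -> +inf), arithmetic mean (p = 1) and the
    # 3rd-power mean with a real cube root, concatenated along the last dim.
    return torch.cat([
        x.min(dim).values,
        x.max(dim).values,
        x.mean(dim),
        real_cbrt(x.pow(3).mean(dim)),
    ], dim=-1)


# Two tokens with one feature each: the 3rd-power mean of -9 and -3 is -(378**(1/3)).
print(power_mean_features(torch.tensor([[-9.0], [-3.0]], dtype=torch.float64)))
```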

Thanks @tom, I'm not sure I follow you. It is true that the paper only suggests using max, min, and odd power means from 1 to 10.
However, as far as I understand, this does not solve the issue of negative numbers.
For example, taking the 3rd-power mean of -9 and -3 (gen_mean([-9, -3], 3)) tries to compute np.power(-378, 1/3), which returns a complex (principal) root.
Would you mind elaborating?
Best,

Well, -(378**(1/3)) is a root, too, so I’d start with that.
If not, you can precompute (-1)**(1/3) ≈ (0.5+0.866j) and “outer”-multiply 378**(1/3) (or whatever magnitude you get) by the two-element tensor tensor([0.5, 0.866]) if the mean is negative, and by tensor([1.0, 0]) if the mean is positive.
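To make the two options concrete, here is a small standard-library check for the -9, -3 example. It assumes Python 3, where raising a negative number to a fractional power yields the principal complex root; the variable names are only for illustration. Both the real root and the principal root cube back to -378, and the principal root is exactly the magnitude times the “magical” factor.

```python
import cmath

m = -378.0                               # mean of cubes for [-9, -3]: (-729 + -27) / 2
magnitude = abs(m) ** (1 / 3)            # 378**(1/3), roughly 7.2304

real_root = -magnitude                   # -(378**(1/3)), the real cube root
magical = cmath.exp(1j * cmath.pi / 3)   # principal (-1)**(1/3), about 0.5 + 0.866j
principal_root = complex(m) ** (1 / 3)   # Python's principal cube root of -378

print(real_root, principal_root, magnitude * magical)
```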

Best regards

Thomas

Thanks @tom, I think you’re right; this solution will do.
Here is what it looks like, by the way:

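import torch

def p_mean_pytorch(tensor, power):
    # Mean of x**power along dim 0
    mean_tensor = tensor.pow(power).mean(0)
    # Real root of |mean| (mean * sign(mean)), then restore the sign (valid for odd powers)
    return (mean_tensor * mean_tensor.sign()).pow(1 / power) * mean_tensor.sign()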

For example:

p_mean_pytorch(torch.tensor([[-100,33,99],[39,9,-10000],[1,3,4],[0,0,0]]).float(), 3)
tensor([  -61.7249,    20.9335, -6299.6050])

I would have liked to get the same root as the paper authors’ implementation, but I guess this will do. I will look at the precomputing part later; the only challenging part is making torch.ger conditional, but I think I can do something with a mask.

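import torch

def power_mean_precompute_3(tensor):
    magical_number = torch.tensor([0.5, 0.866])  # ≈ np.power(-1+0j, 1/3); 0.866 ≈ sqrt(3)/2
    mean_tensor = tensor.pow(3).mean(0)
    mask = torch.le(mean_tensor, 0.)  # True where the mean of cubes is <= 0
    # |mean|^(1/3) where the mean is <= 0 (0 elsewhere), outer product with the complex phase
    # (torch.ger is a deprecated alias of torch.outer in newer PyTorch releases)
    le_result = torch.ger((mean_tensor * mask.float() * mean_tensor.sign()).pow(1/3), magical_number)
    # mean^(1/3) where the mean is > 0 (0 elsewhere), outer product with (1, 0)
    ge_result = torch.ger((mean_tensor * (~mask).float()).pow(1/3), torch.tensor([1., 0]))
    return le_result + ge_result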
power_mean_precompute_3(torch.tensor([[-100,33,99],[39,9,-10000],[1,3,4],[0,0,0]]).float())
tensor([[  30.8625,   53.4538],
        [  20.9335,    0.0000],
        [3149.8025, 5455.4580]])

gen_mean(torch.tensor([[-100,33,99],[39,9,-10000],[1,3,4],[0,0,0]]).numpy(),3)
array([  30.86246739  +53.45536157j,   20.93346287   +0.        j,
       3149.80160592+5455.61641521j])

It seems to do what is intended, but it could use some optimization.
Also, this won’t work for a batch, since torch.ger only accepts 1-D vectors, but adapting it shouldn’t be too hard.

I usually recommend broadcasting for outer products. Here you can combine it with advanced indexing (.long() makes it an index instead of a mask):

import torch

# seq x batch
a = torch.tensor([[-100.,33,99],[39,9,-10000],[1,3,4],[0,0,0]])

magical_number=torch.tensor([[0.5,(0.75)**(0.5)],[1,0]]) # np.power(-1+0j,1/3), 1 ; keep out of function if you want precomputed..
mean_tensor = a.pow(3).mean(0)
magical_number[(mean_tensor>0).long()]*mean_tensor.abs().pow(1/3)[:,None]
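The same indexing trick also addresses the batch question, because advanced indexing with an index tensor of any shape returns that shape plus the trailing (real, imag) axis. A minimal sketch, assuming the sequence lies along dim 0; MAGICAL and power_mean_3_batched are hypothetical names.

```python
import torch

# Row 0: principal cube root of -1 (used when the mean is <= 0); row 1: the root of +1.
MAGICAL = torch.tensor([[0.5, 0.75 ** 0.5], [1.0, 0.0]])


def power_mean_3_batched(x, dim=0):
    mean_cubes = x.pow(3).mean(dim)          # any remaining shape, e.g. (batch, features)
    idx = (mean_cubes > 0).long()            # 1 -> positive mean, 0 -> non-positive
    # Advanced indexing picks a (real, imag) pair per element: shape (..., 2).
    phase = MAGICAL.to(x)[idx]               # match x's dtype and device
    # Broadcast the magnitude over the trailing (real, imag) axis.
    return phase * mean_cubes.abs().pow(1.0 / 3).unsqueeze(-1)


a = torch.tensor([[-100., 33, 99], [39, 9, -10000], [1, 3, 4], [0, 0, 0]])
b = torch.stack([a, 2 * a], dim=1)           # seq x batch x features: (4, 2, 3)
print(power_mean_3_batched(a).shape, power_mean_3_batched(b).shape)
```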

It is still much slower than NumPy, but it might not be the overall bottleneck.

Best regards

Thomas

P.S.: Pro tip: Don’t do torch.tensor(...).float() or .cuda() and the like; instead, always use torch.tensor(..., dtype=torch.float, device=...). It’s more efficient, and once you add requires_grad, the latter gives you a leaf variable while the former does not.
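A quick way to see the leaf difference, sketched on CPU so it runs anywhere (device="cuda" behaves the same way). It uses .double() rather than .float(), because .float() on a tensor that is already float32 returns the tensor itself and would not show the effect.

```python
import torch

# Conversion after creation is a recorded op, so the result is not a leaf.
converted = torch.tensor([1.0, 2.0], requires_grad=True).double()
# Passing dtype (and device) to the constructor creates the leaf directly.
direct = torch.tensor([1.0, 2.0], dtype=torch.double, device="cpu", requires_grad=True)

print(direct.is_leaf, converted.is_leaf)
```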
