What is torch.renorm

I’m confused by the documentation for torch.renorm. Can you tell me what renorm is used for, and how the following example from the docs is computed?

>>> x
tensor([[ 1.,  1.,  1.],
        [ 2.,  2.,  2.],
        [ 3.,  3.,  3.]])
>>> torch.renorm(x, 1, 0, 5)
tensor([[ 1.0000,  1.0000,  1.0000],
        [ 1.6667,  1.6667,  1.6667],
        [ 1.6667,  1.6667,  1.6667]])

I am also having trouble understanding torch.renorm. Here is an example which confuses me:

>>> import torch
>>> x = torch.randn(3, 3)
>>> torch.norm(x, p=2, dim=1)
tensor([2.2026, 2.2433, 1.0954])
>>> y = torch.renorm(x, p=2, dim=1, maxnorm=2)
>>> torch.norm(y, p=2, dim=1)
tensor([2.0249, 1.7163, 0.8764])

Since I specified maxnorm=2, I expected every norm of y along dim=1 (with p=2) to be less than or equal to 2, but the first value is 2.0249. However, if I compute the norms along dim=0 instead:

>>> torch.norm(y, p=2, dim=0)
tensor([2.0000, 0.8703, 1.7483])

all values are at most maxnorm. How does torch.renorm work, and along which dimension does it renormalize the tensor?

A mathematical formulation of renorm would be super helpful.

@ptrblck can this be a bug?

No, I don’t think it’s a bug. It seems you are using the dim argument of renorm incorrectly:

x = torch.randn(3, 3)
x.norm(p=2, dim=1)
# tensor([1.2906, 2.8629, 1.7181])
torch.renorm(x, p=2, dim=0, maxnorm=2).norm(p=2, dim=1)
# tensor([1.2906, 2.0000, 1.7181])

x = torch.randn(3, 3) * 5
x.norm(p=2, dim=1)
# tensor([14.8308,  5.6469,  4.5213])
torch.renorm(x, p=2, dim=0, maxnorm=2).norm(p=2, dim=1)
# tensor([2., 2., 2.])

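According to the docs, dim is:

the dimension to slice over to get the sub-tensors

So with dim=0 each sub-tensor is a row x[i, :], and its norm is the one reported by norm(dim=1). With dim=1 each sub-tensor is a column x[:, i], which is why your norms along dim=0 were the ones capped at maxnorm.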
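Since a mathematical formulation was asked for, here is how I read the documented behaviour. This is a sketch based on the docs rather than on the PyTorch source, so the actual implementation may differ in small details such as an epsilon in the denominator. Here x_i is my notation for the i-th sub-tensor along dim, and j runs over all of its elements (the snippet assumes amsmath):

```latex
% x_i is the i-th sub-tensor obtained by slicing x along dim,
% e.g. x[i, :] for dim = 0 on a 2-D tensor (notation assumed for this sketch).
\[
\lVert x_i \rVert_p = \Bigl( \sum_{j} \lvert x_{i,j} \rvert^{p} \Bigr)^{1/p},
\qquad
y_i =
\begin{cases}
x_i \cdot \dfrac{\text{maxnorm}}{\lVert x_i \rVert_p} & \text{if } \lVert x_i \rVert_p > \text{maxnorm}, \\[1ex]
x_i & \text{otherwise.}
\end{cases}
\]
```

In other words, sub-tensors whose p-norm is already within maxnorm are left untouched, and the others are scaled down so that their p-norm equals maxnorm.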

I also wrote a NumPy implementation of renorm to understand it better.

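import numpy as np

def renorm(x, p, dim, maxnorm):
    # Move `dim` to the front so that x_view[i] is the i-th sub-tensor.
    x_view = np.moveaxis(x, dim, 0)
    # Flatten each sub-tensor and compute its p-norm (not a hard-coded 2-norm).
    flat = x_view.reshape(x_view.shape[0], -1)
    norms = np.linalg.norm(flat, ord=p, axis=1)
    # Scale down only the sub-tensors whose norm exceeds maxnorm.
    # np.maximum guards against division by zero for all-zero sub-tensors.
    factors = np.where(norms > maxnorm, maxnorm / np.maximum(norms, 1e-12), 1.0)
    # Reshape the factors so that they broadcast along `dim`, not only along axis 0.
    shape = [1] * x.ndim
    shape[dim] = -1
    return x * factors.reshape(shape), factors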
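To tie this back to the example from the docs in the first post, here is a dependency-free sketch of the same rule for 2-D nested lists. The helper name renorm_2d is hypothetical and the code is only my reading of the documented behaviour, not how torch implements it.

```python
def renorm_2d(x, p, dim, maxnorm):
    """Sketch of the renorm rule for a 2-D nested list (hypothetical helper, not torch's code)."""
    rows, cols = len(x), len(x[0])
    # The sub-tensors are the rows x[i, :] for dim=0 and the columns x[:, j] for dim=1.
    if dim == 0:
        subs = [list(row) for row in x]
    else:
        subs = [[x[i][j] for i in range(rows)] for j in range(cols)]
    scaled = []
    for sub in subs:
        norm = sum(abs(v) ** p for v in sub) ** (1.0 / p)
        # Only sub-tensors whose p-norm exceeds maxnorm are rescaled.
        factor = maxnorm / norm if norm > maxnorm else 1.0
        scaled.append([v * factor for v in sub])
    if dim == 0:
        return scaled
    # Transpose back so the result has the same layout as the input.
    return [[scaled[j][i] for j in range(cols)] for i in range(rows)]


# The example from the docs: p=1, dim=0, maxnorm=5.
x = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
y = renorm_2d(x, p=1, dim=0, maxnorm=5)
for row in y:
    print([round(v, 4) for v in row])
```

The rows have 1-norms 3, 6 and 9. The first row is within maxnorm=5 and stays as it is, while the other two are scaled by 5/6 and 5/9, which gives 1.6667 in both cases and matches the output shown in the docs.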