Pixel normalization through channels

I’m trying to reproduce the GANSynth paper, and they make use of Pixel Norm, a technique that computes the norm of each pixel across all channels per sample:

b_{x,y} = a_{x,y} / sqrt(1/N * sum_{j=0}^{N-1} (a^j_{x,y})^2 + eps), where N is the number of channels.

How can I have this kind of normalization in PyTorch?

I tried torch.nn.functional.normalize(input, dim=1), but it normalizes across all batch samples at once. Pixel Norm is supposed to normalize within a single sample (similar to LayerNorm).

I think you could try to directly implement the normalization via:

import torch

b, c, h, w = 2, 3, 4, 4
x = torch.randn(b, c, h, w)

# divide each pixel by the root mean square over the channel dimension
out = x / (1/x.size(1) * (x**2).sum(dim=1, keepdim=True)).sqrt()
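
If you need it as a layer, a minimal nn.Module wrapper could look like this (a sketch; the eps value is an assumption taken from the Progressive GAN paper):

import torch
import torch.nn as nn

class PixelNorm(nn.Module):
    """Normalizes each pixel by the RMS over its channels."""
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        # eps value assumed from the Progressive GAN paper
        self.eps = eps

    def forward(self, x):
        # x has shape [batch, channels, height, width]
        return x / (x.pow(2).mean(dim=1, keepdim=True) + self.eps).sqrt()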

If you have a reference implementation, you could compare these approaches and see if I misunderstood the formula.


I’ll check and then post the feedback here.

Update: It looks correct. Thanks!
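
For anyone who wants to double-check it, one quick sanity check (a sketch; it ignores any epsilon term) uses the fact that this pixel norm equals F.normalize along the channel dim scaled by sqrt(C):

import math
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4)
c = x.size(1)

pixel_norm = x / (1/c * (x**2).sum(dim=1, keepdim=True)).sqrt()
# x / sqrt(mean(x^2)) == sqrt(C) * x / ||x||_2 along the channel dim
reference = F.normalize(x, dim=1) * math.sqrt(c)

print(torch.allclose(pixel_norm, reference, atol=1e-6))  # True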

Great and thanks for confirming!

Btw., note that each of these primitive operations will launch a separate CUDA kernel (in case you are using the GPU), so you might not see the best performance.
If you are using PyTorch >=1.12.0, you could try torch.jit.script and let nvFuser generate fast kernels for your workload.

Here is an example which sees a ~3.8x speedup via the generated CUDA kernels:

import time
import torch

def my_norm(x):
    # pixel norm: divide by the RMS over the channel dimension
    out = x / (1/x.size(1) * (x**2).sum(dim=1, keepdim=True)).sqrt()
    return out

iteration_count = 100
x = torch.randn(64, 3, 1024, 1024, device='cuda')

# Eager execution
# Perform warm-up iterations
for _ in range(3):
    output = my_norm(x)

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    output = my_norm(x)

# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))


# Scripted via nvFuser
my_norm_scripted = torch.jit.script(my_norm)

# Perform warm-up iterations
for _ in range(3):
    output = my_norm_scripted(x)

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    output = my_norm_scripted(x)

# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))

Output:

# Eager
Average iterations per second: 135.68

# Scripted with generated kernels from nvFuser
Average iterations per second: 515.90