I’m trying to reproduce the GANSynth paper, which makes use of Pixel Norm, a technique that normalizes each pixel by its norm across all channels, separately for each sample.

How can I have this kind of normalization in PyTorch?

I tried `torch.nn.functional.normalize(input, dim=1)`, but it seems to normalize across all batch samples at once. Pixel Norm is meant to normalize within a single sample (similar to LayerNorm).

I think you could try to directly implement the normalization via:

```
import torch

b, c, h, w = 2, 3, 4, 4
x = torch.randn(b, c, h, w)
# Divide each pixel by the root mean square over its channels (dim=1)
out = x / (1/x.size(1) * (x**2).sum(dim=1, keepdim=True)).sqrt()
```

If you have a reference implementation, you could compare the two approaches and check whether I misunderstood the formula.
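In case it helps with the comparison, here is a minimal sketch of the same computation wrapped in a module. The class name `PixelNorm` and the `eps` default of `1e-8` are my own assumptions, not something taken from a reference implementation; the small epsilon is only there to avoid dividing by zero. As far as I can tell, `F.normalize(x, dim=1)` divides by the L2 norm over the channel dimension for each sample and pixel, so rescaling it by `sqrt(C)` should give the same values, and the sketch uses that as a cross-check.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PixelNorm(nn.Module):
    """Hypothetical helper: scales each pixel's channel vector to unit RMS."""

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        # eps guards against division by zero; 1e-8 is an assumed default.
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mean of squares over the channel dim (dim=1), kept for broadcasting.
        mean_sq = x.pow(2).mean(dim=1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)


torch.manual_seed(0)
x = torch.randn(2, 3, 4, 4)
out = PixelNorm()(x)

# Each (sample, pixel) position should now have a mean square of ~1
# across its channels.
per_pixel_mean_sq = out.pow(2).mean(dim=1)

# Cross-check: F.normalize with dim=1 divides by the L2 norm over channels,
# so rescaling by sqrt(C) should match up to the effect of eps.
reference = F.normalize(x, dim=1) * math.sqrt(x.size(1))
max_diff = (out - reference).abs().max().item()
print("max abs difference vs. F.normalize * sqrt(C):", max_diff)
```

Nothing is reduced over the batch dimension here, so each sample is normalized independently of the others.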


I’ll check and then post the feedback here.

Update: It looks correct. Thanks!

Great and thanks for confirming!

Btw., note that each of these primitive operations would launch a separate CUDA kernel (in case you are using the GPU) so you might not see the best performance.

If you are using PyTorch `>=1.12.0` you could try to `torch.jit.script` it and allow `nvFuser` to code generate fast kernels for your workload.

Here is an example which sees a ~3.8x speedup via the generated CUDA kernels:

```
import time

import torch


def my_norm(x):
    out = x / (1/x.size(1) * (x**2).sum(dim=1, keepdim=True)).sqrt()
    return out


iteration_count = 100
x = torch.randn(64, 3, 1024, 1024, device='cuda')

# Eager execution
# Perform warm-up iterations
for _ in range(3):
    output = my_norm(x)

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    output = my_norm(x)

# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()

iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))

# Scripted via nvFuser
my_norm_scripted = torch.jit.script(my_norm)

# Perform warm-up iterations
for _ in range(3):
    output = my_norm_scripted(x)

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    output = my_norm_scripted(x)

# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()

iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))
```

Output:

```
# Eager
Average iterations per second: 135.68
# Scripted with generated kernels from nvFuser
Average iterations per second: 515.90
```
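Before relying on the speedup, it is probably worth checking that the scripted function returns the same values as the eager one. Below is a rough sketch of such a check. The small CPU tensor and the tolerances are assumptions I picked so it runs anywhere, so it does not exercise the GPU kernels themselves; for that you would move `x` to `'cuda'`.

```python
import torch


def my_norm(x):
    # Same expression as in the benchmark above.
    out = x / (1 / x.size(1) * (x**2).sum(dim=1, keepdim=True)).sqrt()
    return out


my_norm_scripted = torch.jit.script(my_norm)

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)  # small tensor, assumed only for this check

# Call the scripted function a few times, since TorchScript may only apply
# its optimizations after some profiling runs.
for _ in range(3):
    scripted_out = my_norm_scripted(x)
eager_out = my_norm(x)

# Tolerances are assumed; fused kernels can differ in the last few bits.
matches = torch.allclose(eager_out, scripted_out, rtol=1e-5, atol=1e-6)
print("Eager and scripted outputs match:", matches)
```

On PyTorch 2.x, `torch.compile` might be the more common route for this kind of fusion, but I have not benchmarked it for this example.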
