Pixel normalization through channels

Great, and thanks for confirming!

Btw, note that each of these primitive operations would launch a separate CUDA kernel (if you are using the GPU), so you might not see the best performance.
If you are using PyTorch >=1.12.0, you could try to torch.jit.script the function and let nvFuser code-generate fast fused kernels for your workload.

Here is an example which sees a ~3.8x speedup via the generated CUDA kernels:

import time
import torch

def my_norm(x):
    # Normalize each pixel by the RMS of its values across the channel dimension
    out = x / (1 / x.size(1) * (x**2).sum(dim=1, keepdim=True)).sqrt()
    return out

iteration_count = 100
x = torch.randn(64, 3, 1024, 1024, device='cuda')

# Eager execution
# Perform warm-up iterations
for _ in range(3):
    output = my_norm(x)

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    output = my_norm(x)

# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))


# Scripted via nvFuser
my_norm_scripted = torch.jit.script(my_norm)

# Perform warm-up iterations
for _ in range(3):
    output = my_norm_scripted(x)

# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(iteration_count):
    output = my_norm_scripted(x)

# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("Average iterations per second: {:.2f}".format(iters_per_second))

Output:

# Eager
Average iterations per second: 135.68

# Scripted with generated kernels from nvFuser
Average iterations per second: 515.90
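
As a quick sanity check (not part of the timing above), you could also confirm that scripting doesn't change the result by comparing the eager and scripted outputs; something like this should pass up to floating point tolerance:

# Assumption: eager and scripted outputs should match within a small tolerance
assert torch.allclose(my_norm(x), my_norm_scripted(x), atol=1e-6)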