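I only have two GPUs, so I tested `size == 1` and `size == 2`, timing the forward pass with CUDA events. It looks like the forward pass on 2 GPUs is actually faster? (Note that each process builds its pyramid with `sigma_bins=48 // size`, so with two GPUs each one processes only about half of the sigmas.) The output and the code I am running are below: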

```
====== size = 1 ======
Iteration 0 forward latency is 340.7067565917969
Iteration 1 forward latency is 46.39555358886719
Iteration 2 forward latency is 46.37984085083008
Iteration 3 forward latency is 46.37712097167969
Iteration 4 forward latency is 46.3746223449707
Iteration 5 forward latency is 46.35868835449219
Iteration 6 forward latency is 46.370174407958984
Iteration 7 forward latency is 46.40425491333008
Iteration 8 forward latency is 46.36265563964844
Iteration 9 forward latency is 46.36454391479492
end - start = 0.7640056293457747
====== size = 2 ======
Iteration 0 forward latency is 336.1044616699219
Iteration 1 forward latency is 26.22003173828125
Iteration 2 forward latency is 27.49286460876465
Iteration 3 forward latency is 26.249248504638672
Iteration 4 forward latency is 26.69696044921875
Iteration 5 forward latency is 26.118335723876953
Iteration 6 forward latency is 27.30339241027832
Iteration 7 forward latency is 23.886367797851562
Iteration 8 forward latency is 23.869632720947266
Iteration 9 forward latency is 23.936511993408203
end - start = 0.5738828824833035
Iteration 0 forward latency is 312.13189697265625
Iteration 1 forward latency is 24.0633602142334
Iteration 2 forward latency is 23.685983657836914
Iteration 3 forward latency is 23.70742416381836
Iteration 4 forward latency is 23.703231811523438
Iteration 5 forward latency is 23.78976058959961
Iteration 6 forward latency is 23.779136657714844
Iteration 7 forward latency is 23.787424087524414
Iteration 8 forward latency is 23.791616439819336
Iteration 9 forward latency is 23.80246353149414
end - start = 2.9916703598573804
```

```
import math
import numbers
import os
import time
from functools import partial
from typing import Tuple
import numpy as np
import torch
import torch.distributed as dist
from opt_einsum import contract
from torch import nn
from torch.multiprocessing import set_start_method, Pool
class DifferenceOfGaussiansFFT(nn.Module):
    """Gaussian scale-space pyramid evaluated by FFT convolution.

    ``__init__`` builds ``sigma_bins + 1`` Gaussian kernels, zero-pads them to
    a power-of-two FFT size large enough to avoid circular wrap-around, and
    stores their one-sided spectra. ``forward`` multiplies the spectrum of the
    input with every kernel spectrum and returns the blurred images; no
    differences between neighbouring scales are taken inside this module.

    NOTE(review): ``torch.rfft`` / ``torch.irfft`` belong to the pre-``torch.fft``
    API and are not available in newer PyTorch releases -- confirm the pinned
    PyTorch version.
    """

    def __init__(
        self,
        *,
        img_height: int,
        img_width: int,
        min_sigma: int = 1,
        max_sigma: int = 10,
        sigma_bins: int = 50,
        truncate: float = 5.0,
    ):
        """Precompute the spectra of all Gaussian kernels.

        Parameters
        ----------
        img_height, img_width: spatial size every input must have.
        min_sigma, max_sigma: endpoints of the ``sigma_bins`` evenly spaced
            sigmas; one extra sigma, one step above ``max_sigma``, is appended.
        sigma_bins: number of evenly spaced sigmas. Must be >= 2 so the step
            used for the extra sigma is defined.
        truncate: kernel radius in multiples of sigma.

        Raises
        ------
        ValueError: if ``sigma_bins < 2``.
        """
        super(DifferenceOfGaussiansFFT, self).__init__()
        if sigma_bins < 2:
            # sigma_bins == 1 used to die with ZeroDivisionError in the step
            # below; sigma_bins == 0 silently produced a single-kernel pyramid.
            raise ValueError(f"sigma_bins must be >= 2, got {sigma_bins}")
        self.img_height = img_height
        self.img_width = img_width
        self.signal_ndim = 2
        # sigma_bins evenly spaced sigmas plus one extra sigma one step past max_sigma.
        self.sigma_list = np.concatenate(
            [
                np.linspace(min_sigma, max_sigma, sigma_bins),
                [max_sigma + (max_sigma - min_sigma) / (sigma_bins - 1)],
            ]
        )
        sigmas = torch.from_numpy(self.sigma_list)
        self.register_buffer("sigmas", sigmas)
        # The largest kernel fixes the window every kernel is centred in.
        self.max_radius = int(truncate * float(sigmas.max()) + 0.5)
        max_bandwidth = 2 * self.max_radius + 1
        # Linear (not circular) convolution needs img + kernel - 1 samples per
        # axis; round up to the next power of 2 for a cheaper FFT.
        padded_height = img_height + max_bandwidth - 1
        padded_width = img_width + max_bandwidth - 1
        self.fft_height = 2 ** math.ceil(math.log2(padded_height))
        self.fft_width = 2 ** math.ceil(math.log2(padded_width))
        # Zero-pad inputs on the right/bottom up to the FFT size.
        self.pad_input = nn.ConstantPad2d(
            (0, self.fft_width - img_width, 0, self.fft_height - img_height), 0
        )
        kernel_pad = nn.ConstantPad2d(
            # left, right, top, bottom
            (0, self.fft_width - max_bandwidth, 0, self.fft_height - max_bandwidth),
            0,
        )
        f_kernels = []
        for s in sigmas:
            radius = int(truncate * s + 0.5)
            width = 2 * radius + 1
            kernel = torch_gaussian_kernel(width=width, sigma=s.item())
            # Centre every kernel in a max_bandwidth window so that all of
            # them shift the FFT output by the same fixed amount (max_radius),
            # which forward() undoes with a single crop.
            center_pad_size = self.max_radius - radius
            if center_pad_size > 0:
                kernel = nn.ConstantPad2d(center_pad_size, 0)(kernel)
            f_kernels.append(
                torch.rfft(
                    kernel_pad(kernel), signal_ndim=self.signal_ndim, onesided=True
                )
            )
        # requires_grad=False: the kernel spectra are fixed, not learned.
        self.f_gaussian_pyramid = nn.Parameter(
            torch.stack(f_kernels, dim=0), requires_grad=False
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Blur ``input`` with every Gaussian of the pyramid.

        Parameters
        ----------
        input: real tensor of shape (batch, img_height, img_width); the
            ``"buv"`` subscripts used by ``comp_mul`` require exactly one
            leading batch dimension.

        Returns
        -------
        Tensor of shape (batch, n_sigmas, img_height, img_width), where
        ``n_sigmas == len(self.sigma_list)``.
        """
        img_height, img_width = list(input.size())[-self.signal_ndim:]
        assert (img_height, img_width) == (self.img_height, self.img_width)
        padded_input = self.pad_input(input)
        f_input = torch.rfft(padded_input, signal_ndim=self.signal_ndim, onesided=True)
        f_gaussian_images = comp_mul(self.f_gaussian_pyramid, f_input)
        gaussian_images = torch.irfft(
            f_gaussian_images,
            signal_ndim=self.signal_ndim,
            onesided=True,
            # Real output size of the transformed (trailing) dims; needed
            # because the one-sided last axis does not determine it.
            signal_sizes=padded_input.shape[-self.signal_ndim:],
        )
        # fft induces a shift of max_radius (kernels are left/top aligned);
        # crop it away to recover the original image window.
        return gaussian_images[
            :,  # batch dimension
            :,  # filter dimension
            self.max_radius: self.img_height + self.max_radius,
            self.max_radius: self.img_width + self.max_radius,
        ]
def torch_gaussian_kernel(
    width: int = 21, sigma: int = 3, dim: int = 2
) -> torch.Tensor:
    """Build a separable Gaussian kernel whose entries sum to 1.

    Parameters
    ----------
    width: number of samples per axis; a single number is used for every
        axis, otherwise one entry per axis.
    sigma: standard deviation per axis; a single number is used for every
        axis, otherwise one entry per axis.
    dim: number of axes when ``width`` / ``sigma`` are given as scalars
        (2 for images).

    Returns
    -------
    kernel : float32 tensor with one axis per entry of ``width``, normalised
        so that its elements sum to 1.
    """
    widths = [width] * dim if isinstance(width, numbers.Number) else width
    stds = [sigma] * dim if isinstance(sigma, numbers.Number) else sigma

    axes = [torch.arange(n, dtype=torch.float32) for n in widths]
    grids = torch.meshgrid(axes)

    kernel = 1
    for n, std, grid in zip(widths, stds, grids):
        centre = (n - 1) / 2
        norm = 1 / (std * math.sqrt(2 * math.pi))
        # Multiply in this axis' 1-D Gaussian, broadcast over the grid.
        kernel = kernel * (norm * torch.exp(-(((grid - centre) / std) ** 2) / 2))

    # Normalise so the kernel preserves the mean intensity.
    return kernel / torch.sum(kernel)
def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst``, each holding ``n`` items.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``.
    Nothing is evaluated until the generator is first advanced.
    """
    yield from (lst[offset: offset + n] for offset in range(0, len(lst), n))
def comp_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Multiply every complex filter with every complex image, element-wise.

    Complex values are stored with a trailing axis of size 2:
    ``t[..., 0]`` is the real part and ``t[..., 1]`` the imaginary part, so
    ``(a + ib)(c + id) = (ac - bd) + i(ad + bc)``.

    Examples
    ________
    >>> f_filters = torch.rand((20, 1024, 1024, 2))
    >>> f_imgs = torch.rand((5, 1024, 1024, 2))
    >>> f_filtered_imgs = comp_mul(f_filters, f_imgs)

    Parameters
    ----------
    x : filters, shape (f, u, v, 2)
    y : images, shape (b, u, v, 2); the (u, v) axes must match ``x``.

    Returns
    -------
    z : shape (b, f, u, v, 2); ``z[i, j] = y[i] * x[j]`` (complex product).
    """
    assert x.shape[-1] == y.shape[-1] == 2

    def hadamard_all_pairs(filt, img):
        # Hadamard product of every filter against every batch image.
        return contract("fuv,buv->bfuv", filt, img)

    x_re, x_imag = x.unbind(-1)
    y_re, y_imag = y.unbind(-1)
    real_part = hadamard_all_pairs(x_re, y_re) - hadamard_all_pairs(x_imag, y_imag)
    imag_part = hadamard_all_pairs(x_re, y_imag) + hadamard_all_pairs(x_imag, y_re)
    return torch.stack([real_part, imag_part], dim=-1)
def run(rank, size):
    """Time 10 forward passes of ``DifferenceOfGaussiansFFT`` on GPU ``rank``.

    Each rank builds its own pyramid with ``sigma_bins = 48 // size``, so with
    ``size == 2`` every GPU processes roughly half the sigmas of the
    ``size == 1`` run. Per-iteration latency is measured with CUDA events on
    the current stream of device ``rank``; the wall-clock time of the whole
    loop is printed at the end. Iteration 0 includes one-time warm-up cost
    (compare the first line of each run in the log above).

    Returns the list of 10 output tensors, which live on GPU ``rank``.
    NOTE(review): when called from a Pool worker, these CUDA tensors are sent
    back to the parent process; if the results are actually needed there,
    consider returning ``.cpu()`` copies instead.
    """
    with torch.no_grad():
        image = torch.rand((1, 1000, 1000))
        model = DifferenceOfGaussiansFFT(
            img_height=1000,
            img_width=1000,
            sigma_bins=48 // size,
            max_sigma=30,
        ).to(rank, non_blocking=True)
        for param in model.parameters():
            param.requires_grad = False
        model.eval()
        torch.cuda.synchronize(rank)

        outputs = []
        wall_start = time.monotonic()
        stream = torch.cuda.current_stream(rank)
        ev_before = torch.cuda.Event(enable_timing=True)
        ev_after = torch.cuda.Event(enable_timing=True)
        for iteration in range(10):
            # No-op after the first iteration, once the image is on the GPU.
            image = image.to(rank)
            stream.record_event(ev_before)
            outputs.append(model(image))
            stream.record_event(ev_after)
            # Block the host until the forward pass has finished on the GPU.
            ev_after.synchronize()
            latency = ev_before.elapsed_time(ev_after)
            print(f"Iteration {iteration} forward latency is {latency}")
        wall_end = time.monotonic()
        print("end - start = ", wall_end - wall_start)
        torch.cuda.synchronize(rank)
    return outputs
def init_process(rank_size_fn, backend="nccl"):
    """Initialize the distributed environment, run ``fn``, then tear it down.

    Parameters
    ----------
    rank_size_fn: ``(rank, size, fn)`` tuple, packed into one argument so this
        function can be passed directly to ``Pool.map``.
    backend: ``torch.distributed`` backend name.

    Returns
    -------
    Whatever ``fn(rank, size)`` returns.
    """
    rank, size, fn = rank_size_fn
    # Single-host rendezvous on a fixed address/port.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    try:
        return fn(rank, size)
    finally:
        # Destroy the default group even if fn raises. Otherwise a Pool worker
        # that is reused for another task would fail in init_process_group
        # because the default group is already initialized.
        dist.destroy_process_group()
if __name__ == "__main__":
    # CUDA cannot be re-initialized in forked children, so use spawn.
    set_start_method("spawn")
    for size in (1, 2):
        print(f"====== size = {size} ======")
        pool = Pool(processes=size)
        # One task per rank. The results (CUDA tensors shared back from the
        # workers) are not used, so they are dropped right away instead of
        # being kept alive while the workers shut down.
        pool.map(init_process, [(i, size, run) for i in range(size)])
        pool.close()
        # Wait until this run's workers have exited before the next run
        # rendezvouses again on the same MASTER_ADDR/MASTER_PORT.
        pool.join()
