Does it ever make sense to try model parallelism even if the model fits?

I have a model that fits perfectly fine on any one of my GPUs (4x 1080Tis), but I had the bright idea that maybe I could speed up a forward pass (at inference time) by partitioning one of the layers (a very “tall” Conv2d, i.e. >20 output channels) across all of my GPUs. So I used DDP to map across my GPUs and surprisingly (or maybe not?) the forward pass actually gets slower as the number of GPUs increases. Is this to be expected?

I’m not an expert on the GPU execution pipeline, but is it the case that any individual CUDA kernel (e.g. my “tall” Conv2d) already gets executed in parallel? I’m guessing that that’s my issue: the layer I’m partitioning already runs in parallel on a single device, and the scatter/gather just adds copy (and process-instantiation) latencies.

Hey @makslevental, could you please elaborate on how you managed to use DDP after splitting one layer? Does it mean each DDP process now sees a different model? Or that each DDP process no longer has exclusive access to its own GPU?

and surprisingly (or maybe not?) the forward pass actually gets slower as the number of GPUs increases. Is this to be expected?

With the split, each forward pass involves cross-device communication, so it is possible to see slowdowns.
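
A rough way to see that cost in isolation (a sketch assuming at least two visible GPUs) is to time the kind of device-to-device copy a gather implies, with explicit syncs so the async copy is fully captured:

    import time
    import torch

    # time a cross-device copy like the one a gather across GPUs incurs
    x = torch.rand(1, 12, 1024, 1024, device="cuda:1")
    torch.cuda.synchronize("cuda:1")
    start = time.monotonic()
    y = x.to("cuda:0")
    torch.cuda.synchronize("cuda:0")
    print(f"copy took {(time.monotonic() - start) * 1e3:.2f} ms")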

@mrshenli I split my network across the layer boundary, so yes, each DDP process now sees a different model: imagine, instead of one 48-output-channel Conv2d, applying four 12-output-channel Conv2ds.
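
To make the split concrete, here is a minimal sketch (shapes chosen arbitrarily for illustration) showing that one 48-channel Conv2d is equivalent to four 12-channel Conv2ds whose outputs are concatenated along the channel dimension:

    import torch
    from torch import nn

    full = nn.Conv2d(3, 48, kernel_size=3, padding=1)
    parts = [nn.Conv2d(3, 12, kernel_size=3, padding=1) for _ in range(4)]

    # copy slices of the full conv's weights into the four shards so outputs match
    with torch.no_grad():
        for i, part in enumerate(parts):
            part.weight.copy_(full.weight[i * 12:(i + 1) * 12])
            part.bias.copy_(full.bias[i * 12:(i + 1) * 12])

    x = torch.rand(1, 3, 64, 64)
    combined = torch.cat([p(x) for p in parts], dim=1)  # concat along channels
    assert torch.allclose(full(x), combined, atol=1e-6)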

Why is there cross-talk? I run with torch.no_grad() and for p in model.parameters(): p.requires_grad = False in every run.

I see. Is this forward pass only for inference, or do you also run backward for training? If it is the latter, this might break the correctness of DDP: DDP expects the model in each process to be exactly the same; otherwise, the AllReduce communication across DDP processes could mess up the gradients.

Why is there cross-talk? I run with torch.no_grad() and for p in model.parameters(): p.requires_grad = False in every run.

Looks like you are doing inference instead of training? In that case, don’t you need to somehow gather/combine the outputs of the four different Conv2d layers from the 4 different DDP processes? Otherwise, how do you get the final inference result?

BTW, since this is inference only, why do you need DDP?

Sorry, actually I just realized I’ve completely misspoken. I’m not wrapping my model in DDP. I was planning on doing this, and then I realized it replicates the same model across nodes, whereas I need to send distinct (but related) models to nodes.

Yes, this is correct: I do a map and then I plan on doing a concat to reconstruct the output channels as if they all came from the same Conv2d. I think you can basically get the idea from this code snippet:

import os
import time
from pathlib import Path

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.multiprocessing import set_start_method, Pool

# SimulPLIF, DifferenceOfGaussiansFFT, and nn_dog come from my project;
# a self-contained version of DifferenceOfGaussiansFFT is posted further down.


def run(rank, size):
    with torch.no_grad():
        image_pth = Path(os.path.dirname(os.path.realpath(__file__))) / Path(
            "../simulation/screenshot.png"
        )

        screenshot = SimulPLIF(img_path=image_pth, num_repeats=1, load_truth=False)
        img_height, img_width = screenshot[0].squeeze(0).numpy().shape

        from nn_dog import PIN_MEMORY

        train_dataloader = DataLoader(screenshot, batch_size=1, pin_memory=PIN_MEMORY)
        dog = DifferenceOfGaussiansFFT(
            img_height=img_height,
            img_width=img_width,
            sigma_bins=48 // size,
            max_sigma=30,
        ).to(rank, non_blocking=PIN_MEMORY)
        for p in dog.parameters():
            p.requires_grad = False
        dog.eval()
        torch.cuda.synchronize(rank)

        dogs = []
        for i in range(10):
            img_tensor = next(iter(train_dataloader))
            img_tensor = img_tensor.to(rank)
            torch.cuda.synchronize(rank)
            dogs.append(dog(img_tensor))
        return dogs


def init_process(rank_size_fn, backend="nccl"):
    """Initialize the distributed environment."""
    rank, size, fn = rank_size_fn
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    return fn(rank, size)


if __name__ == "__main__":
    set_start_method("spawn")

    size = 4
    pool = Pool(processes=size)
    start = time.monotonic()
    res = pool.map(init_process, [(i, size, run) for i in range(size)])
    end = time.monotonic()
    print(end - start)
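
The concat step isn’t in the snippet yet; hypothetically (and modulo CUDA tensors surviving the trip back through pool.map), it would be something like:

    # hypothetical reconstruction: move each rank's last output to one
    # device and concatenate along the filter dimension, as if all 48
    # output channels came from a single model
    outputs = [dogs[-1].to("cuda:0") for dogs in res]
    full_pyramid = torch.cat(outputs, dim=1)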

I see. How did you measure the latency of the forward pass? Did you use elapsed_time from CUDA events to wrap dog(img_tensor)?
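
I.e., something roughly like this (a sketch reusing dog and img_tensor from your snippet):

    e_start = torch.cuda.Event(enable_timing=True)
    e_finish = torch.cuda.Event(enable_timing=True)
    e_start.record()
    out = dog(img_tensor)
    e_finish.record()
    e_finish.synchronize()  # block until the recorded GPU work completes
    print(f"forward latency: {e_start.elapsed_time(e_finish)} ms")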

It’s a rough measure, but it’s right there:

    start = time.monotonic()
    res = pool.map(init_process, [(i, size, run) for i in range(size)])
    end = time.monotonic()
    print(end - start)

Then, since run repeats the forward pass 10 times, I reason that I’m amortizing process instantiation across those 10 inference passes. And so, between 1 and 4 GPUs, end - start just grows.

Could you please wrap the above code with some time measurement and check what percentage it contributes to the total end - start delay?

BTW, you might be able to avoid the torch.cuda.synchronize(rank) above, as the copy (tensor.to) should synchronize the destination device properly.

Okay, I removed the sync and enclosed the loop in time.monotonic.

2 gpus:
1.2319540430326015
1.3151386808604002
total: 6.296403981978074

3 gpus:
1.1622967889998108
1.3194731972180307
1.3116707119625062
total: 5.875259702792391

4 gpus:
1.1516663811635226
1.4554521720856428
1.76222850009799
1.8313195349182934
total: 6.504983321996406

I ran this several times; it’s consistent.

The large variance (1.15 vs 1.83) seems to suggest there is some sort of contention. I wonder if that is caused by the data loading. What if we use CUDA event elapsed_time to measure dog(img_tensor)? Note that time.monotonic() does not guarantee correct measurements, as there could still be CUDA ops pending in the stream.

Or is it possible to get a self-contained example that we can investigate locally? E.g., using torch.rand to create random inputs instead of loading from screenshot.png.

This is self-contained; except for the standard stuff (numpy, pytorch) you only need opt_einsum:


import math
import numbers
import os
import time
from functools import partial
from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist
from opt_einsum import contract
from torch import nn
from torch.multiprocessing import set_start_method, Pool

class DifferenceOfGaussiansFFT(nn.Module):
    def __init__(
            self,
            *,
            img_height: int,
            img_width: int,
            min_sigma: int = 1,
            max_sigma: int = 10,
            sigma_bins: int = 50,
            truncate: float = 5.0,
    ):
        super(DifferenceOfGaussiansFFT, self).__init__()
        self.img_height = img_height
        self.img_width = img_width
        self.signal_ndim = 2

        self.sigma_list = np.concatenate(
            [
                np.linspace(min_sigma, max_sigma, sigma_bins),
                [max_sigma + (max_sigma - min_sigma) / (sigma_bins - 1)],
            ]
        )
        sigmas = torch.from_numpy(self.sigma_list)
        self.register_buffer("sigmas", sigmas)
        # print("gaussian pyramid sigmas: ", len(sigmas), sigmas)

        # accommodate largest filter
        self.max_radius = int(truncate * max(sigmas) + 0.5)
        max_bandwidth = 2 * self.max_radius + 1
        # pad fft to prevent aliasing
        padded_height = img_height + max_bandwidth - 1
        padded_width = img_width + max_bandwidth - 1
        # round up to next power of 2 for cheaper fft.
        self.fft_height = 2 ** math.ceil(math.log2(padded_height))
        self.fft_width = 2 ** math.ceil(math.log2(padded_width))
        self.pad_input = nn.ConstantPad2d(
            (0, self.fft_width - img_width, 0, self.fft_height - img_height), 0
        )

        self.f_gaussian_pyramid = []
        kernel_pad = nn.ConstantPad2d(
            # left, right, top, bottom
            (0, self.fft_width - max_bandwidth, 0, self.fft_height - max_bandwidth),
            0,
        )
        for i, s in enumerate(sigmas):
            radius = int(truncate * s + 0.5)
            width = 2 * radius + 1
            kernel = torch_gaussian_kernel(width=width, sigma=s.item())

            # this is to align all of the kernels so that the eventual fft shifts a fixed amount
            center_pad_size = self.max_radius - radius
            if center_pad_size > 0:
                centered_kernel = nn.ConstantPad2d(center_pad_size, 0)(kernel)
            else:
                centered_kernel = kernel

            padded_kernel = kernel_pad(centered_kernel)

            f_kernel = torch.rfft(
                padded_kernel, signal_ndim=self.signal_ndim, onesided=True
            )
            self.f_gaussian_pyramid.append(f_kernel)

        self.f_gaussian_pyramid = nn.Parameter(
            torch.stack(self.f_gaussian_pyramid, dim=0), requires_grad=False
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        img_height, img_width = list(input.size())[-self.signal_ndim:]
        assert (img_height, img_width) == (self.img_height, self.img_width)

        padded_input = self.pad_input(input)
        f_input = torch.rfft(padded_input, signal_ndim=self.signal_ndim, onesided=True)
        f_gaussian_images = comp_mul(self.f_gaussian_pyramid, f_input)
        gaussian_images = torch.irfft(
            f_gaussian_images,
            signal_ndim=self.signal_ndim,
            onesided=True,
            signal_sizes=padded_input.shape[1:],
        )

        # fft induces a shift so needs to be undone
        gaussian_images = gaussian_images[
                          :,  # batch dimension
                          :,  # filter dimension
                          self.max_radius: self.img_height + self.max_radius,
                          self.max_radius: self.img_width + self.max_radius,
                          ]

        return gaussian_images


def torch_gaussian_kernel(
        width: int = 21, sigma: int = 3, dim: int = 2
) -> torch.Tensor:
    """Gaussian kernel

    Parameters
    ----------
    width: bandwidth of the kernel
    sigma: std of the kernel
    dim: dimensions of the kernel (images -> 2)

    Returns
    -------
    kernel : gaussian kernel

    """

    if isinstance(width, numbers.Number):
        width = [width] * dim
    if isinstance(sigma, numbers.Number):
        sigma = [sigma] * dim
    kernel = 1
    meshgrids = torch.meshgrid(
        [torch.arange(size, dtype=torch.float32) for size in width]
    )
    for size, std, mgrid in zip(width, sigma, meshgrids):
        mean = (size - 1) / 2
        kernel *= (
                1
                / (std * math.sqrt(2 * math.pi))
                * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
        )

    # Make sure sum of values in gaussian kernel equals 1.
    kernel = kernel / torch.sum(kernel)
    return kernel


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i: i + n]


def comp_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Complex multiplies two complex 3d tensors

    x = (x_real, x_im)
    y = (y_real, y_im)
    x*y = (x_real*y_real - x_im*y_im, x_real*y_im + x_im*y_real)

    Last dimension has size 2, with x[..., 0] the real part and x[..., 1] the imaginary part.
    Dimensions (-3, -2) of both x and y must be the same.

    Examples
    --------
    >>> f_filters = torch.rand((20, 1024, 1024, 2))
    >>> f_imgs = torch.rand((5, 1024, 1024, 2))
    >>> f_filtered_imgs = comp_mul(f_filters, f_imgs)

    Parameters
    ----------
    x : Last dimension is (a,b) of a+ib
    y : Last dimension is (a,b) of a+ib

    Returns
    -------
    z : x*y

    """

    # hadamard product of every filter against every batch image
    op = partial(contract, "fuv,buv->bfuv")
    assert x.shape[-1] == y.shape[-1] == 2
    x_real, x_im = x.unbind(-1)
    y_real, y_im = y.unbind(-1)
    z = torch.stack(
        [op(x_real, y_real) - op(x_im, y_im), op(x_real, y_im) + op(x_im, y_real)],
        dim=-1,
    )
    return z


def run(rank, size):
    with torch.no_grad():
        img_tensor = torch.rand((1, 1, 1000, 1000))

        dog = DifferenceOfGaussiansFFT(
            img_height=1000,
            img_width=1000,
            sigma_bins=48 // size,
            max_sigma=30,
        ).to(rank, non_blocking=True)
        for p in dog.parameters():
            p.requires_grad = False
        dog.eval()
        torch.cuda.synchronize(rank)

        dogs = []
        start = time.monotonic()
        for i in range(10):
            img_tensor = img_tensor.to(rank)
            # torch.cuda.synchronize(rank)
            dogs.append(dog(img_tensor))
        end = time.monotonic()
        print(end - start)
        return dogs


def init_process(rank_size_fn, backend="nccl"):
    """Initialize the distributed environment."""
    rank, size, fn = rank_size_fn
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    return fn(rank, size)


if __name__ == "__main__":
    set_start_method("spawn")

    size = 2
    pool = Pool(processes=size)
    start = time.monotonic()
    res = pool.map(init_process, [(i, size, run) for i in range(size)])
    end = time.monotonic()
    print(end - start)
    pool.close()

    size = 3
    pool = Pool(processes=size)
    start = time.monotonic()
    res = pool.map(init_process, [(i, size, run) for i in range(size)])
    end = time.monotonic()
    print(end - start)
    pool.close()

    size = 4
    pool = Pool(processes=size)
    start = time.monotonic()
    res = pool.map(init_process, [(i, size, run) for i in range(size)])
    end = time.monotonic()
    print(end - start)
    pool.close()
    # print(res)

Thanks for helping me with this, btw!

1 Like

I conda-installed opt_einsum but hit the following error. Is there a specific version of opt_einsum that I should use? The installed one is opt_einsum-3.2.1.

Traceback (most recent call last):
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/multiprocessing/pool.py", line 48, in mapstar
    return list(map(*args))
  File "/scratch/shenli/pytorch/test.py", line 229, in init_process
    return fn(rank, size)
  File "/scratch/shenli/pytorch/test.py", line 215, in run
    dogs.append(dog(img_tensor))
  File "/scratch/shenli/pytorch/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/scratch/shenli/pytorch/test.py", line 89, in forward
    f_gaussian_images = comp_mul(self.f_gaussian_pyramid, f_input)
  File "/scratch/shenli/pytorch/test.py", line 185, in comp_mul
    [op(x_real, y_real) - op(x_im, y_im), op(x_real, y_im) + op(x_im, y_real)],
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/site-packages/opt_einsum/contract.py", line 473, in contract
    operands, contraction_list = contract_path(*operands,
  File "/private/home/shenli/local/miniconda/envs/torchdev/lib/python3.8/site-packages/opt_einsum/contract.py", line 222, in contract_path
    raise ValueError("Einstein sum subscript '{}' does not contain the "
ValueError: Einstein sum subscript 'buv' does not contain the correct number of indices for operand 1.


It’s because I got the dimensions of img_tensor wrong; I guess it should be (1, 1000, 1000).

I only have two GPUs, so I tested size == 1 and size == 2 using CUDA events. It looks like the forward pass with 2 GPUs is actually faster? I attached the code I am running below:

====== size = 1  ======
Iteration 0 forward latency is 340.7067565917969
Iteration 1 forward latency is 46.39555358886719
Iteration 2 forward latency is 46.37984085083008
Iteration 3 forward latency is 46.37712097167969
Iteration 4 forward latency is 46.3746223449707
Iteration 5 forward latency is 46.35868835449219
Iteration 6 forward latency is 46.370174407958984
Iteration 7 forward latency is 46.40425491333008
Iteration 8 forward latency is 46.36265563964844
Iteration 9 forward latency is 46.36454391479492
end - start =  0.7640056293457747
====== size = 2  ======
Iteration 0 forward latency is 336.1044616699219
Iteration 1 forward latency is 26.22003173828125
Iteration 2 forward latency is 27.49286460876465
Iteration 3 forward latency is 26.249248504638672
Iteration 4 forward latency is 26.69696044921875
Iteration 5 forward latency is 26.118335723876953
Iteration 6 forward latency is 27.30339241027832
Iteration 7 forward latency is 23.886367797851562
Iteration 8 forward latency is 23.869632720947266
Iteration 9 forward latency is 23.936511993408203
end - start =  0.5738828824833035
Iteration 0 forward latency is 312.13189697265625
Iteration 1 forward latency is 24.0633602142334
Iteration 2 forward latency is 23.685983657836914
Iteration 3 forward latency is 23.70742416381836
Iteration 4 forward latency is 23.703231811523438
Iteration 5 forward latency is 23.78976058959961
Iteration 6 forward latency is 23.779136657714844
Iteration 7 forward latency is 23.787424087524414
Iteration 8 forward latency is 23.791616439819336
Iteration 9 forward latency is 23.80246353149414
end - start =  2.9916703598573804
import math
import numbers
import os
import time
from functools import partial
from typing import Tuple

import numpy as np
import torch
import torch.distributed as dist
from opt_einsum import contract
from torch import nn
from torch.multiprocessing import set_start_method, Pool

class DifferenceOfGaussiansFFT(nn.Module):
    def __init__(
            self,
            *,
            img_height: int,
            img_width: int,
            min_sigma: int = 1,
            max_sigma: int = 10,
            sigma_bins: int = 50,
            truncate: float = 5.0,
    ):
        super(DifferenceOfGaussiansFFT, self).__init__()
        self.img_height = img_height
        self.img_width = img_width
        self.signal_ndim = 2

        self.sigma_list = np.concatenate(
            [
                np.linspace(min_sigma, max_sigma, sigma_bins),
                [max_sigma + (max_sigma - min_sigma) / (sigma_bins - 1)],
            ]
        )
        sigmas = torch.from_numpy(self.sigma_list)
        self.register_buffer("sigmas", sigmas)
        # print("gaussian pyramid sigmas: ", len(sigmas), sigmas)

        # accommodate largest filter
        self.max_radius = int(truncate * max(sigmas) + 0.5)
        max_bandwidth = 2 * self.max_radius + 1
        # pad fft to prevent aliasing
        padded_height = img_height + max_bandwidth - 1
        padded_width = img_width + max_bandwidth - 1
        # round up to next power of 2 for cheaper fft.
        self.fft_height = 2 ** math.ceil(math.log2(padded_height))
        self.fft_width = 2 ** math.ceil(math.log2(padded_width))
        self.pad_input = nn.ConstantPad2d(
            (0, self.fft_width - img_width, 0, self.fft_height - img_height), 0
        )

        self.f_gaussian_pyramid = []
        kernel_pad = nn.ConstantPad2d(
            # left, right, top, bottom
            (0, self.fft_width - max_bandwidth, 0, self.fft_height - max_bandwidth),
            0,
        )
        for i, s in enumerate(sigmas):
            radius = int(truncate * s + 0.5)
            width = 2 * radius + 1
            kernel = torch_gaussian_kernel(width=width, sigma=s.item())

            # this is to align all of the kernels so that the eventual fft shifts a fixed amount
            center_pad_size = self.max_radius - radius
            if center_pad_size > 0:
                centered_kernel = nn.ConstantPad2d(center_pad_size, 0)(kernel)
            else:
                centered_kernel = kernel

            padded_kernel = kernel_pad(centered_kernel)

            f_kernel = torch.rfft(
                padded_kernel, signal_ndim=self.signal_ndim, onesided=True
            )
            self.f_gaussian_pyramid.append(f_kernel)

        self.f_gaussian_pyramid = nn.Parameter(
            torch.stack(self.f_gaussian_pyramid, dim=0), requires_grad=False
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        img_height, img_width = list(input.size())[-self.signal_ndim:]
        assert (img_height, img_width) == (self.img_height, self.img_width)

        padded_input = self.pad_input(input)
        f_input = torch.rfft(padded_input, signal_ndim=self.signal_ndim, onesided=True)
        f_gaussian_images = comp_mul(self.f_gaussian_pyramid, f_input)
        gaussian_images = torch.irfft(
            f_gaussian_images,
            signal_ndim=self.signal_ndim,
            onesided=True,
            signal_sizes=padded_input.shape[1:],
        )

        # fft induces a shift so needs to be undone
        gaussian_images = gaussian_images[
                          :,  # batch dimension
                          :,  # filter dimension
                          self.max_radius: self.img_height + self.max_radius,
                          self.max_radius: self.img_width + self.max_radius,
                          ]

        return gaussian_images


def torch_gaussian_kernel(
        width: int = 21, sigma: int = 3, dim: int = 2
) -> torch.Tensor:
    """Gaussian kernel

    Parameters
    ----------
    width: bandwidth of the kernel
    sigma: std of the kernel
    dim: dimensions of the kernel (images -> 2)

    Returns
    -------
    kernel : gaussian kernel

    """

    if isinstance(width, numbers.Number):
        width = [width] * dim
    if isinstance(sigma, numbers.Number):
        sigma = [sigma] * dim
    kernel = 1
    meshgrids = torch.meshgrid(
        [torch.arange(size, dtype=torch.float32) for size in width]
    )
    for size, std, mgrid in zip(width, sigma, meshgrids):
        mean = (size - 1) / 2
        kernel *= (
                1
                / (std * math.sqrt(2 * math.pi))
                * torch.exp(-(((mgrid - mean) / std) ** 2) / 2)
        )

    # Make sure sum of values in gaussian kernel equals 1.
    kernel = kernel / torch.sum(kernel)
    return kernel


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i: i + n]


def comp_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Complex multiplies two complex 3d tensors

    x = (x_real, x_im)
    y = (y_real, y_im)
    x*y = (x_real*y_real - x_im*y_im, x_real*y_im + x_im*y_real)

    Last dimension has size 2, with x[..., 0] the real part and x[..., 1] the imaginary part.
    Dimensions (-3, -2) of both x and y must be the same.

    Examples
    --------
    >>> f_filters = torch.rand((20, 1024, 1024, 2))
    >>> f_imgs = torch.rand((5, 1024, 1024, 2))
    >>> f_filtered_imgs = comp_mul(f_filters, f_imgs)

    Parameters
    ----------
    x : Last dimension is (a,b) of a+ib
    y : Last dimension is (a,b) of a+ib

    Returns
    -------
    z : x*y

    """

    # hadamard product of every filter against every batch image
    op = partial(contract, "fuv,buv->bfuv")
    assert x.shape[-1] == y.shape[-1] == 2
    x_real, x_im = x.unbind(-1)
    y_real, y_im = y.unbind(-1)
    z = torch.stack(
        [op(x_real, y_real) - op(x_im, y_im), op(x_real, y_im) + op(x_im, y_real)],
        dim=-1,
    )
    return z


def run(rank, size):
    with torch.no_grad():
        img_tensor = torch.rand((1, 1000, 1000))

        dog = DifferenceOfGaussiansFFT(
            img_height=1000,
            img_width=1000,
            sigma_bins=48 // size,
            max_sigma=30,
        ).to(rank, non_blocking=True)
        for p in dog.parameters():
            p.requires_grad = False
        dog.eval()
        torch.cuda.synchronize(rank)

        dogs = []
        start = time.monotonic()
        s = torch.cuda.current_stream(rank)
        e_start = torch.cuda.Event(enable_timing=True)
        e_finish = torch.cuda.Event(enable_timing=True)
        for i in range(10):
            img_tensor = img_tensor.to(rank)
            # torch.cuda.synchronize(rank)
            s.record_event(e_start)
            dogs.append(dog(img_tensor))
            s.record_event(e_finish)
            e_finish.synchronize()
            print(f"Iteration {i} forward latency is {e_start.elapsed_time(e_finish)}")
        end = time.monotonic()
        print("end - start = ", end - start)
        torch.cuda.synchronize(rank)
        return dogs


def init_process(rank_size_fn, backend="nccl"):
    """Initialize the distributed environment."""
    rank, size, fn = rank_size_fn
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    return fn(rank, size)


if __name__ == "__main__":
    set_start_method("spawn")

    size = 1
    print("====== size = 1  ======")
    pool = Pool(processes=size)
    start = time.monotonic()
    res = pool.map(init_process, [(i, size, run) for i in range(size)])
    end = time.monotonic()
    #print(end - start)
    pool.close()

    print("====== size = 2  ======")

    size = 2
    pool = Pool(processes=size)
    start = time.monotonic()
    res = pool.map(init_process, [(i, size, run) for i in range(size)])
    end = time.monotonic()
    #print(end - start)
    pool.close()
    # print(res)

@mrshenli Okay, thanks, I’ll try this. But this also means all of my other timings have been wrong, so thanks for showing me how to use CUDA events too.

edit:

@mrshenli It turns out you need a synchronize after all:

        start = time.monotonic()
        s = torch.cuda.current_stream(rank)
        e_start = torch.cuda.Event(enable_timing=True)
        e_finish = torch.cuda.Event(enable_timing=True)
        s.record_event(e_start)
        for i in range(10):
            img_tensor = img_tensor_cpu.to(rank)
            # torch.cuda.synchronize(rank)
            dogs.append(dog(img_tensor))
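        # without this sync, the finish event is recorded before the queued forward passes complete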
        torch.cuda.synchronize(rank)
        s.record_event(e_finish)
        e_finish.synchronize()
        end = time.monotonic()

gives me this:

====== size = 1  ======

rank 0 Iteration 9 forward latency is 1283.0596923828125
end - start =  1.2832060940563679


====== size = 2  ======

rank 0 Iteration 9 forward latency is 626.5835571289062
end - start =  0.6267032357864082
rank 1 Iteration 9 forward latency is 640.3717041015625
end - start =  0.6404897100292146


====== size = 3  ======

rank 0 Iteration 9 forward latency is 443.1278076171875
end - start =  0.44322703895159066
rank 1 Iteration 9 forward latency is 471.8766174316406
end - start =  0.47198665188625455
rank 2 Iteration 9 forward latency is 461.29559326171875
end - start =  0.46140363393351436


====== size = 4  ======

rank 0 Iteration 9 forward latency is 397.9264221191406
end - start =  0.3981346560176462
rank 2 Iteration 9 forward latency is 374.9112243652344
end - start =  0.3749916541855782
rank 3 Iteration 9 forward latency is 360.9978942871094
end - start =  0.3610941809602082
rank 1 Iteration 9 forward latency is 362.57073974609375
end - start =  0.3626508240122348

@mrshenli I got everything working. Thanks for your help!