Computing the Lipschitz constant of a deep network efficiently

Hi, I want to compute the Lipschitz constant of a network, but it is slow. Here is how I implement it:

# naive version
lc = 1
for block in self.blocks:
    lc = lc * compute_lipschitz(block)

The function compute_lipschitz looks like this:

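def compute_lipschitz(conv):
    # Power iteration on W^T W: x converges toward the top right-singular vector of the conv operator W
    x = torch.randn_like(input_of_conv)  # random tensor with the conv's input shape
    for i in range(5):
        x = F.conv2d(x, conv.weight, None, conv.stride, conv.padding)            # x <- W x
        x = F.conv_transpose2d(x, conv.weight, None, conv.stride, conv.padding)  # x <- W^T x (may need output_padding if stride > 1)
        x = x / x.norm()
    # For a unit-norm x, ||W x|| estimates the largest singular value, i.e. the Lipschitz constant
    return F.conv2d(x, conv.weight, None, conv.stride, conv.padding).norm()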

for convolution layers. For activation layers, it returns the maximum absolute derivative; for ReLU, for example, it returns 1.
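To sanity-check this kind of power iteration, here is a minimal self-contained sketch. The names conv_lipschitz and dense_spectral_norm are hypothetical, and it assumes a stride-1, bias-free 2D convolution on a small input. It compares the power-iteration estimate with the exact spectral norm of the convolution written out as a dense matrix. Since ||Wx|| for a unit vector x can never exceed the largest singular value, the estimate approaches the true value from below.

```python
import math

import torch
import torch.nn.functional as F


def conv_lipschitz(weight, input_shape, n_iters=50, eps=1e-12, padding=1):
    """Power-iteration estimate of the spectral norm of x -> conv2d(x, weight),
    assuming stride 1 and no bias (a bias does not change the Lipschitz constant)."""
    x = torch.randn(1, *input_shape, dtype=weight.dtype, device=weight.device)
    x = x / (x.norm() + eps)
    for _ in range(n_iters):
        y = F.conv2d(x, weight, padding=padding)            # W x
        x = F.conv_transpose2d(y, weight, padding=padding)  # W^T (W x), same shape as x at stride 1
        x = x / (x.norm() + eps)
    # For a unit vector x, ||W x|| is a lower bound on the largest singular value
    return F.conv2d(x, weight, padding=padding).norm()


def dense_spectral_norm(weight, input_shape, padding=1):
    """Exact spectral norm, obtained by materializing the convolution as a dense matrix.
    Only practical for tiny inputs; meant as a reference for testing."""
    n = math.prod(input_shape)
    basis = torch.eye(n, dtype=weight.dtype, device=weight.device).reshape(n, *input_shape)
    rows = F.conv2d(basis, weight, padding=padding).reshape(n, -1)  # row i is W e_i
    return torch.linalg.matrix_norm(rows, ord=2)  # transpose of W has the same singular values
```

For example, with `w = torch.randn(4, 3, 3, 3, dtype=torch.float64)` and `input_shape = (3, 6, 6)`, `conv_lipschitz(w, input_shape, n_iters=300)` should come out slightly below or equal to `dense_spectral_norm(w, input_shape)`. In practice the iterate can be kept between calls (as a buffer) so that only a few iterations are needed per training step.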

However, it is very slow because 1) the network is deep, 2) each compute_lipschitz call is lightweight and cannot make full use of the GPU, and 3) with DDP the same computation is repeated on every GPU, which I think is unnecessary.

I tried two ideas to solve this problem, but neither worked:

  1. Use two CUDA streams, one for the model forward pass and one for compute_lipschitz. This did not give much of a speedup.

  2. With DDP on 8 GPUs, have each GPU compute the Lipschitz constants of only 1/8 of the blocks:

# accelerate version
lc = torch.zeros(len(self.blocks)).to(gpu_device)
for idx, block in enumerate(self.blocks):
    if idx % num_gpus == rank:  # each rank handles every num_gpus-th block
        lc[idx] = compute_lipschitz(block)

torch.distributed.all_reduce(lc)  # sums the per-rank partial vectors (default op is SUM)
lc = lc.prod()

This version is 8 times faster, and lc has the same value as in the naive version. However, the gradient seems to be wrong: training behaves very differently from the naive version, even though the value of lc is computed correctly.

Any ideas on how to fix the gradient problem or speed up the naive version? Thanks!

This is a very interesting problem :slight_smile:

Some initial clarifications:

  1. More precisely, is lc = torch.ones(1) in the naive version? (You mention that it requires gradient, so I assume it is a torch.Tensor.)
  2. According to the docs for torch.nn.utils.spectral_norm, that function modifies the parameter in place. Does compute_lipschitz do any in-place modification, or does it simply return a value? If it returns a value, is it a Python float, a torch.Tensor, or something else?
  3. Is there any data dependency between the model forward and computing the model Lipschitz constant?

Also, my guess is that the gradient is wrong because the autograd history of lc = torch.zeros(len(self.blocks)) is not propagated through the all-reduce. Each rank only has the history for its own fraction of the elements, so gradients from the loss will not flow back through lc as expected.
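Here is a single-process sketch of that guess. There is no real process group: the other ranks' values are simply added in place under torch.no_grad(), which is how a non-autograd in-place reduction looks to autograd. The helper names and the toy scalars standing in for compute_lipschitz(block) are hypothetical. Only the entries this rank filled in keep a gradient path. One further inference, not verified in this thread: if DDP then averages parameter gradients across ranks, each block's Lipschitz gradient would arrive scaled by 1/world_size, which alone could make training behave differently.

```python
import torch


def local_lc_vector(lcs, rank, world_size):
    # Mimics the loop in the accelerated version on one rank: only entries with
    # idx % world_size == rank are filled in, so only they carry autograd history
    lc = torch.zeros(len(lcs))
    for idx, value in enumerate(lcs):
        if idx % world_size == rank:
            lc[idx] = value
    return lc


def non_autograd_all_reduce(lc, lcs, rank, world_size):
    # Emulates dist.all_reduce(lc) from this rank's point of view: the other ranks'
    # values are added in place outside autograd, so they arrive as constants
    with torch.no_grad():
        others = torch.stack([v.detach() if idx % world_size != rank else torch.zeros(())
                              for idx, v in enumerate(lcs)])
        lc += others
    return lc
```

With four scalars 2, 3, 5, 7 on rank 0 of 2, the product is correct (210), but after `backward()` only the scalars at indices 0 and 2 receive a gradient; the ones owned by rank 1 get none on this rank.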

Finally, is there any way to write a batched compute_lipschitz(blocks) that can perform the prod in the kernel directly?


It looks like if you use the autograd-aware all_reduce API, torch.distributed.nn.functional.all_reduce, the autograd history will be propagated appropriately (it calls an additional all-reduce in the backward pass). Maybe try dropping it in instead of dist.all_reduce and see if that fixes your problem.

Hi @agu, thanks for your reply. I edited the question to make it clearer.

  1. No, lc is not torch.ones(1) in the naive version. I initialize it with the number 1 and multiply it by every block's Lipschitz constant.

  2. compute_lipschitz does not do any in-place operations; it just returns a value: one that carries a gradient if the module has parameters (e.g., a conv layer), or one without a gradient if it has none (e.g., a ReLU layer).

  3. Computing the model's Lipschitz constant is independent of the model forward pass.

I will see if I can write a batched compute_lipschitz for layers whose parameters have the same shape.
What is the difference between torch.distributed.all_reduce and torch.distributed.nn.functional.all_reduce? Thanks!


Sorry if I was unclear. The difference is that nn.functional.all_reduce should propagate gradients from all ranks, while the plain distributed.all_reduce does not. As I understand it, if you create a tensor t = torch.zeros(world_size), run t[rank] = <some computation tracked by autograd> on each rank, call nn.functional.all_reduce(t) on each rank, and use the resulting t to compute the loss, then loss.backward() will provide gradients for all world_size elements of t, not just t[rank]. This is because nn.functional.all_reduce additionally all-reduces each rank's gradient in the backward pass.
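A single-process analogy of that behaviour, as a hedged sketch: stacking every rank's partial vector and summing it stands in for the autograd-aware all-reduce (torch.stack(...).sum is not the real collective, and the helper names are hypothetical). Because the sum is recorded by autograd, backward() gives every element a gradient, not just the one a given rank filled in. Building each rank's contribution out of place with torch.stack also avoids writing into a preallocated zeros tensor.

```python
import torch


def rank_contribution(lcs, rank, world_size):
    # This rank's partial vector, built out of place: owned entries keep their autograd
    # history, entries owned by other ranks are constant zeros
    return torch.stack([v if idx % world_size == rank else torch.zeros(())
                        for idx, v in enumerate(lcs)])


def emulated_autograd_all_reduce(contributions):
    # Stand-in for an autograd-aware all-reduce: the forward pass sums the ranks' tensors,
    # and backpropagating through the sum hands the gradient to every rank's contribution
    return torch.stack(contributions).sum(dim=0)
```

With scalars 2, 3, 5, 7 split over two emulated ranks, the reduced vector's product is 210 and every scalar receives the gradient 210 / value. In a real multi-process run each rank has its own loss, so this only mirrors the collective's gradient flow in spirit.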

Regarding your answers to 1 and 2, my concern is that I do not think gradients can propagate from scalars (i.e., Python floats) back to torch.Tensors. Either way, I still suggest trying nn.functional.all_reduce first to see if that works.

@agu After switching to torch.distributed.nn.functional.all_reduce, the program gets stuck at the first iteration :upside_down_face:

:frowning:

Would it be possible to share a minimal version of your script? I can try to take a look.

Sure, I will put one together today.

@agu Sorry for the late reply; I was busy yesterday. Here is a minimal version. I removed some complicated loss terms so that the main idea stays clear.

Batching, as you suggested, solved the problem: the runtime of my model dropped by 40%. But it would still be interesting to understand why all_reduce does not work here.

import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F


class conv3x3(nn.Conv2d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 n_power_iterations=10,
                 eps=1e-12,
                 input_size=32) -> None:
        super(conv3x3, self).__init__(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=3,
                                      stride=1,
                                      padding=1,
                                      bias=True)
        nn.init.orthogonal_(self.weight)

        self.num_iter = n_power_iterations
        self.eps = eps

        init_x = torch.randn(1, self.in_channels, input_size, input_size)
        self.register_buffer('init_x', init_x)

    def lipschitz(self):
        x = self.init_x.data

        # Power Method to find the lipschitz
        for _ in range(self.num_iter):
            x = F.conv2d(x, self.weight, bias=None, padding=1)
            x = F.conv_transpose2d(x, self.weight, bias=None, padding=1)
            x = x / (x.norm() + self.eps)

        self.init_x += (x - self.init_x).detach()
        x = F.conv2d(x,
                     self.weight,
                     bias=None,
                     stride=self.stride,
                     padding=self.padding)
        return x.norm()


class ReLU(nn.ReLU):
    def __init__(self):
        super(ReLU, self).__init__()

    def lipschitz(self):
        return 1.


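class ToyModel(nn.Module):
    def __init__(self, depth):
        super(ToyModel, self).__init__()
        layers = []
        for _ in range(depth):
            layers.append(conv3x3(64, 64))
            layers.append(ReLU())

        # nn.Sequential registers the submodules (so their parameters are trained) and makes self.layers(x) callable
        self.layers = nn.Sequential(*layers)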

    def forward(self, x):
        out = self.layers(x)
        lc = self.lipschitz()
        return out, lc

    def lipschitz(self):
        if dist.is_initialized():
            return self.fast_lipschitz()
        lc = 1
        for module in self.layers:
            lc = lc * module.lipschitz()
        return lc

    def fast_lipschitz(self):
        lc_list = torch.zeros(len(self.layers))
        lc_list = lc_list.to(self.layers[0].weight.device)

        rank = dist.get_rank()
        num_gpus = dist.get_world_size()

        for idx, module in enumerate(self.layers):
            if idx % num_gpus == rank:
                lc_list[idx] = module.lipschitz()

        dist.all_reduce(lc_list)
        return lc_list.prod()


if __name__ == '__main__':
    dist.init_process_group(backend="nccl")
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int)
    args = parser.parse_args()

    local_rank = args.local_rank
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = ToyModel(20).to(device)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    inputs = torch.randn(8, 64, 32, 32).to(device)
    targets = torch.randn(8, 64, 32, 32).to(device)

    optimizer = torch.optim.Adam(model.parameters())

    for _ in range(100):
        optimizer.zero_grad()
        outputs, lc = model(inputs)
        loss = (outputs - targets).pow(2).mean() + lc * 0.1
        loss.backward()
        optimizer.step()

Today I tried this, and it is really fast. The Lipschitz constants of layers whose weights have the same shape can be computed together. Say my residual network has 3 stages, and within each stage all convolutional layers have the same shape. I just concatenate the weights of these layers (treating them as a single group convolution) and do the computation at once. Although the total amount of computation is unchanged, batching is much faster.
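The grouped-convolution trick described above can be sketched as follows, assuming stride-1, bias-free convolutions with identical weight shapes within a stage (batched_conv_lipschitz and its x0 argument are hypothetical names). The kernels are concatenated along the output-channel dimension and run with groups=k, so a single conv2d/conv_transpose2d pair performs one power-iteration step for all k layers at once; each layer's slice of the iterate is normalized separately.

```python
import torch
import torch.nn.functional as F


def _normalize_per_group(t, k, eps):
    # Each of the k layers owns a contiguous block of channels; scale each block to unit norm
    flat = t.reshape(k, -1)
    flat = flat / (flat.norm(dim=1, keepdim=True) + eps)
    return flat.reshape(t.shape)


def batched_conv_lipschitz(weights, input_shape, n_iters=50, eps=1e-12, padding=1, x0=None):
    """Power-iteration estimates of the spectral norms of k same-shape, stride-1
    convolutions, computed together as one grouped convolution with groups=k."""
    k = len(weights)
    c_in, height, width = input_shape
    w_cat = torch.cat(weights, dim=0)  # (k * C_out, C_in, kH, kW)
    if x0 is None:
        x0 = torch.randn(1, k * c_in, height, width, dtype=w_cat.dtype, device=w_cat.device)
    x = _normalize_per_group(x0, k, eps)
    for _ in range(n_iters):
        y = F.conv2d(x, w_cat, padding=padding, groups=k)            # group g: W_g x_g
        x = F.conv_transpose2d(y, w_cat, padding=padding, groups=k)  # group g: W_g^T W_g x_g
        x = _normalize_per_group(x, k, eps)
    y = F.conv2d(x, w_cat, padding=padding, groups=k)
    return y.reshape(k, -1).norm(dim=1)  # one estimate per layer, shape (k,)
```

Because torch.cat is differentiable, the product of the returned estimates (e.g. `est.prod()`) gives a stage's Lipschitz factor that backpropagates into every layer's weight. Unlike conv3x3 in the script, this sketch starts from a fresh random vector on each call instead of keeping the iterate in a buffer, so it needs more iterations per call.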
