Torch.nn.utils.clip_grad.clip_grad_norm_ is too slow

The forward pass takes 0.00435 s and the loss computation takes 0.00138 s, but clip_grad_norm_ needs 9.3306293487 s, which clearly slows down training.
The code is below:

clip_grad.clip_grad_norm_(filterd_params, **grad_clip_config)
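For context, these numbers come from plain wall-clock timestamps around each step. A minimal reconstruction of that measurement is sketched below; the toy model, shapes, and grad_clip_config values are made up for illustration and are not the real training code:

import time
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad

# Hypothetical stand-ins for the real model, data and clip settings.
model = nn.Conv2d(3, 96, kernel_size=11, stride=4).cuda()
criterion = nn.MSELoss()
data = torch.randn(8, 3, 224, 224, device="cuda")
target = torch.randn(8, 96, 54, 54, device="cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
filterd_params = [p for p in model.parameters() if p.requires_grad]
grad_clip_config = dict(max_norm=35, norm_type=2)

optimizer.zero_grad()

t0 = time.time()
loss = criterion(model(data), target)     # "forward + loss": a few milliseconds
t1 = time.time()

loss.backward()                           # CUDA kernels are queued asynchronously here

t2 = time.time()
clip_grad.clip_grad_norm_(filterd_params, **grad_clip_config)
t3 = time.time()                          # this interval is the reported ~9.3 s

optimizer.step()
print("forward+loss {:.5f}s, clip {:.5f}s".format(t1 - t0, t3 - t2))

Nothing in this sketch calls torch.cuda.synchronize(), which turns out to matter below.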

Is there any way to speed up model training?

To see where the time goes, I added timing prints to clip_grad_norm_:

import torch
from math import inf  # stands in for torch._six.inf used in the original source


def clip_grad_norm_(parameters, max_norm, norm_type=2):
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    # timing instrumentation added for debugging
    import mmcv
    timer = mmcv.Timer()
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    print("#####", timer.since_start())
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        print(timer.since_start())
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            print("1)params {} device={} for loop, takes".format(p.shape, p.device), timer.since_last_check())
            print("param_norm.item()={}, norm_type={}".format(param_norm.item(), norm_type),timer.since_last_check())
            total_norm += param_norm.item() ** norm_type
            print("2)params {} for loop, takes".format(p.shape), timer.since_last_check())
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    print("clip_coef={}".format(clip_coef), timer.since_last_check())
    if clip_coef < 1:
        for p in parameters:
            p.grad.data.mul_(clip_coef)
            print("forloop2 {} ".format(p.shape), timer.since_last_check())
    return total_norm

The timing output for the first few parameters:
1)params torch.Size([96, 3, 11, 11]) device=cuda:1 for loop, takes 9.655952453613281e-05
param_norm.item()=1.9029390811920166, norm_type=2.0 9.31826901435852
2)params torch.Size([96, 3, 11, 11]) for loop, takes 0.0003921985626220703
1)params torch.Size([96]) device=cuda:1 for loop, takes 0.000354766845703125
param_norm.item()=0.001021179836243391, norm_type=2.0 4.887580871582031e-05
2)params torch.Size([96]) for loop, takes 4.1484832763671875e-05
1)params torch.Size([256, 48, 5, 5]) device=cuda:1 for loop, takes 0.0001227855682373047
param_norm.item()=3.1388771533966064, norm_type=2.0 4.410743713378906e-05
2)params torch.Size([256, 48, 5, 5]) for loop, takes 3.647804260253906e-05
1)params torch.Size([256]) device=cuda:1 for loop, takes 8.0108642578125e-05
param_norm.item()=0.001840393990278244, norm_type=2.0 5.340576171875e-05
2)params torch.Size([256]) for loop, takes 5.1021575927734375e-05

Then I found that the first param_norm.item() call takes almost all of the time.

The TensorFlow API tf.clip_by_global_norm is similar to torch.nn.utils.clip_grad.clip_grad_norm_, but it does not slow down training.
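For comparison, a minimal TensorFlow 2.x sketch of that API; the toy model, shapes, and clip_norm value here are made up for illustration:

import tensorflow as tf

# Hypothetical toy model and data, just to make the sketch self-contained.
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
optimizer = tf.keras.optimizers.SGD(0.01)
x = tf.random.normal([8, 32])
y = tf.random.normal([8, 10])

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
grads = tape.gradient(loss, model.trainable_variables)

# Global-norm clipping: one norm over all gradients, but the norm stays a
# tensor on the device, so there is no explicit device-to-host copy like .item().
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=35.0)
optimizer.apply_gradients(zip(clipped, model.trainable_variables))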

  1. The reported times are likely inaccurate because you are not taking asynchronous CUDA execution into account: the synchronizing .item() call appears to take a long time, but it is actually reporting the time spent by the previous asynchronous calls.
  2. In PyTorch 1.5 the computation in clip_grad_norm_ was changed to avoid synchronizing .item() calls; the idea is sketched below.
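Roughly, the change keeps the total norm on the GPU instead of accumulating Python floats with .item(). A loose sketch of that idea (not the exact library code, and it assumes all gradients live on one device):

import torch


def clip_grad_norm_no_sync(parameters, max_norm, norm_type=2.0):
    """Gradient clipping without per-parameter .item() calls (sketch only)."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    # One norm per gradient, then a norm over those norms: the same value as
    # if all gradients were flattened into a single vector.
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), norm_type) for g in grads]), norm_type
    )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Clamp instead of `if clip_coef < 1`, so nothing has to be copied back to
    # the host to make a Python-level decision.
    clip_coef = torch.clamp(clip_coef, max=1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm

With this shape of computation, a device synchronization only happens if the caller actually reads total_norm, not once per parameter.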

I tried PyTorch 1.5, but training is still as slow as before.

You are right. I should call torch.cuda.synchronize() before recording each time point.
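A minimal sketch of that corrected measurement (the toy model and shapes are made up; only the placement of the synchronize() calls matters):

import time
import torch
import torch.nn as nn

model = nn.Conv2d(3, 96, kernel_size=11, stride=4).cuda()
loss = model(torch.randn(8, 3, 224, 224, device="cuda")).pow(2).mean()
loss.backward()

# Wait for all queued kernels (forward + backward) to finish before starting
# the clock; otherwise their time gets billed to the next synchronizing call,
# such as the first .item().
torch.cuda.synchronize()
t0 = time.time()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=35)
torch.cuda.synchronize()
print("clip_grad_norm_ actually takes {:.6f}s".format(time.time() - t0))

Measured this way, the time previously attributed to the first .item() should show up with the forward/backward work instead.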