torch.nn.utils.clip_grad.clip_grad_norm_ is too slow

The forward pass takes 0.00435s and the loss computation takes 0.00138s, but clip_grad_norm_ takes 9.3306293487s. This noticeably slows down training.
The code is below:

clip_grad.clip_grad_norm_(filterd_params, **grad_clip_config)

Is there any way to speed up training?

import torch
from torch._six import inf  # math.inf works as well


def clip_grad_norm_(parameters, max_norm, norm_type=2):
    r"""Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """
    import mmcv
    timer = mmcv.Timer()
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    print("#####", timer.since_start())
    if norm_type == inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            param_norm = p.grad.data.norm(norm_type)
            print("1)params {} device={} for loop, takes".format(p.shape, p.device), timer.since_last_check())
            print("param_norm.item()={}, norm_type={}".format(param_norm.item(), norm_type), timer.since_last_check())
            total_norm += param_norm.item() ** norm_type
            print("2)params {} for loop, takes".format(p.shape), timer.since_last_check())
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    print("clip_coef={}".format(clip_coef), timer.since_last_check())
    if clip_coef < 1:
        for p in parameters:
            print("forloop2 {} ".format(p.shape), timer.since_last_check())
            p.grad.data.mul_(clip_coef)
    return total_norm

Part of the timing output:
1)params torch.Size([96, 3, 11, 11]) device=cuda:1 for loop, takes 9.655952453613281e-05
param_norm.item()=1.9029390811920166, norm_type=2.0 9.31826901435852
2)params torch.Size([96, 3, 11, 11]) for loop, takes 0.0003921985626220703
1)params torch.Size([96]) device=cuda:1 for loop, takes 0.000354766845703125
param_norm.item()=0.001021179836243391, norm_type=2.0 4.887580871582031e-05
2)params torch.Size([96]) for loop, takes 4.1484832763671875e-05
1)params torch.Size([256, 48, 5, 5]) device=cuda:1 for loop, takes 0.0001227855682373047
param_norm.item()=3.1388771533966064, norm_type=2.0 4.410743713378906e-05
2)params torch.Size([256, 48, 5, 5]) for loop, takes 3.647804260253906e-05
1)params torch.Size([256]) device=cuda:1 for loop, takes 8.0108642578125e-05
param_norm.item()=0.001840393990278244, norm_type=2.0 5.340576171875e-05
2)params torch.Size([256]) for loop, takes 5.1021575927734375e-05

Then I found that the first param_norm.item() call takes almost all of the time.
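This pattern is typical of asynchronous GPU execution: .item() copies a scalar back to the CPU, so it blocks until every previously queued kernel has finished, and the first such call absorbs the cost of all the preceding work. A minimal sketch (hypothetical tensor sizes, falls back to CPU where the effect disappears) to illustrate:

```python
import time
import torch

def timed_norm(device):
    # Queue a large matmul; on CUDA the launch returns immediately
    # because kernel execution is asynchronous.
    x = torch.randn(2048, 2048, device=device)
    y = x @ x
    t0 = time.perf_counter()
    n = y.norm()              # also just queued on CUDA, not yet computed
    launch = time.perf_counter() - t0
    t1 = time.perf_counter()
    value = n.item()          # blocks until the matmul and the norm finish
    sync = time.perf_counter() - t1
    return launch, sync, value

device = "cuda" if torch.cuda.is_available() else "cpu"
launch, sync, value = timed_norm(device)
print("launch: {:.6f}s  .item(): {:.6f}s  norm={:.3f}".format(launch, sync, value))
```

On a GPU, `launch` is tiny while `sync` includes the whole matmul, even though .item() itself does almost no work.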

The TensorFlow API tf.clip_by_global_norm is similar to clip_grad_norm_, but it doesn't slow training down.

  1. The reported times are likely inaccurate because you are not taking asynchronous CUDA execution into account: the synchronizing .item() call appears to take a long time, but it is actually reporting the time spent in earlier asynchronous calls.
  2. In PyTorch 1.5, the computation in clip_grad_norm_ was changed to avoid synchronizing .item() calls.
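To get trustworthy per-op timings, the device has to be synchronized before each timestamp is taken. A hedged sketch of that measurement pattern (the helper names are illustrative; the synchronize is a no-op fallback on CPU):

```python
import time
import torch

def sync():
    # On CUDA, wait for all queued kernels to finish; no-op on CPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def measure(fn, *args):
    sync()                       # don't charge earlier async work to fn
    t0 = time.perf_counter()
    out = fn(*args)
    sync()                       # make sure fn's work is done before stopping the clock
    return out, time.perf_counter() - t0

device = "cuda" if torch.cuda.is_available() else "cpu"
g = torch.randn(256, 48, 5, 5, device=device)   # shaped like one gradient from the log
norm, elapsed = measure(lambda t: t.norm(2).item(), g)
print("norm={:.4f} computed in {:.6f}s".format(norm, elapsed))
```

With this bracketing, the 9.3s would show up on whichever operation actually performs the work, rather than on the first .item() call.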

I tried PyTorch 1.5, but training is still as slow as before.

You are right. I should call torch.cuda.synchronize() before recording each time point.