Network training results change with different PyTorch versions

Hi, I am trying to train a deep model that consists of a ResNet50 and a GAN.

In torch==1.2.0, the model trains well and the performance is very good.

But I want to use torch.cuda.amp, which is only available in torch>=1.6.0.

So I tried training the model with torch==1.6.0, and the model performance became much worse. (I have not used torch.cuda.amp yet.)

What could actually cause this when only the torch version changes and everything else stays the same (same data, same layers, same loss…)?

Do layers or operations behave differently when the torch version changes?

Could you share some more details about the setup (e.g., GPU, precision, etc.)? If the loss is the same, it smells like a difference is appearing somewhere during evaluation.

I used the same setup (TITAN RTX GPU, Linux, etc.) and only changed the torch version:
Python 3.6 with PyTorch 1.2 → Python 3.6 with PyTorch 1.6.

The dramatic change occurs when I update the weights of the network during training.

For example, the prediction and loss of the first training iteration are identical in both versions, but after I call optimizer.step(), the prediction and loss of the second iteration are very different.

In PyTorch 1.6 this eventually results in very poor performance.

I checked the documentation of optimizer.step(), but there seems to be no difference between PyTorch 1.2 and 1.6.

This is the code for the optimizers that train my network.

You could directly compare adam.py between the two releases; the moved bias_correction computation seems to be one significant change that could potentially cause the different behavior:

git diff v1.2.0 v1.6.0 -- torch/optim/adam.py
diff --git a/torch/optim/adam.py b/torch/optim/adam.py
index edcfcc26be..9d68613c64 100644
--- a/torch/optim/adam.py
+++ b/torch/optim/adam.py
@@ -37,6 +37,8 @@ class Adam(Optimizer):
             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        if not 0.0 <= weight_decay:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
         defaults = dict(lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad)
         super(Adam, self).__init__(params, defaults)
@@ -46,6 +48,7 @@ class Adam(Optimizer):
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
+    @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
 
@@ -55,13 +58,14 @@ class Adam(Optimizer):
         """
         loss = None
         if closure is not None:
-            loss = closure()
+            with torch.enable_grad():
+                loss = closure()
 
         for group in self.param_groups:
             for p in group['params']:
                 if p.grad is None:
                     continue
-                grad = p.grad.data
+                grad = p.grad
                 if grad.is_sparse:
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                 amsgrad = group['amsgrad']
@@ -72,12 +76,12 @@ class Adam(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     # Exponential moving average of squared gradient values
-                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     if amsgrad:
                         # Maintains max of all exp. moving avg. of sq. grad. values
-                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                 if amsgrad:
@@ -85,25 +89,25 @@ class Adam(Optimizer):
                 beta1, beta2 = group['betas']
 
                 state['step'] += 1
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
 
                 if group['weight_decay'] != 0:
-                    grad.add_(group['weight_decay'], p.data)
+                    grad = grad.add(p, alpha=group['weight_decay'])
 
                 # Decay the first and second moment running average coefficient
-                exp_avg.mul_(beta1).add_(1 - beta1, grad)
-                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 if amsgrad:
                     # Maintains the maximum of all 2nd moment running avg. till now
                     torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                     # Use the max. for normalizing running avg. of gradient
-                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                 else:
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
-                bias_correction1 = 1 - beta1 ** state['step']
-                bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                step_size = group['lr'] / bias_correction1
 
-                p.data.addcdiv_(-step_size, exp_avg, denom)
+                p.addcdiv_(exp_avg, denom, value=-step_size)
 
         return loss
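To see what the bias_correction reordering actually changes, here is a minimal pure-Python sketch that replays a single scalar Adam step in both orders from the diff. It assumes the default hyperparameters (lr=1e-3, betas=(0.9, 0.999), eps=1e-8), no weight decay, no amsgrad, and a fresh state; the function names are hypothetical, and it deliberately ignores the other changes in the diff (torch.no_grad(), p.grad vs. p.grad.data).

```python
import math


def adam_step_v12(p, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam update in the v1.2.0 order: eps is added before bias correction."""
    step += 1
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    denom = math.sqrt(v) + eps
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = lr * math.sqrt(bias_correction2) / bias_correction1
    return p - step_size * m / denom, m, v, step


def adam_step_v16(p, grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam update in the v1.6.0 order: eps is added after bias correction."""
    step += 1
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    denom = math.sqrt(v) / math.sqrt(bias_correction2) + eps
    step_size = lr / bias_correction1
    return p - step_size * m / denom, m, v, step


# First step from a fresh state (m = v = 0, step = 0), parameter starting at 1.0.
normal_12 = adam_step_v12(1.0, 0.5, 0.0, 0.0, 0)[0]
normal_16 = adam_step_v16(1.0, 0.5, 0.0, 0.0, 0)[0]
tiny_12 = adam_step_v12(1.0, 1e-6, 0.0, 0.0, 0)[0]
tiny_16 = adam_step_v16(1.0, 1e-6, 0.0, 0.0, 0)[0]
print(f"grad=0.5:  v1.2 -> {normal_12:.10f}, v1.6 -> {normal_16:.10f}")
print(f"grad=1e-6: v1.2 -> {tiny_12:.10f}, v1.6 -> {tiny_16:.10f}")
```

Algebraically, the v1.2.0 update equals the v1.6.0 update with eps replaced by eps / sqrt(bias_correction2), which is larger during early steps. For gradients much larger than eps the two agree almost exactly; only for very small gradients (or a large eps) do the early updates differ noticeably.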

As a test, copy the v1.2.0 adam.py file into your project and train your model with it to see whether these changes are indeed causing the issue.


Replacing the v1.6 adam.py with a copy of the v1.2 adam.py did not fix the problem for me.

I also found another difference in the forward pass.
In my code, l_pos.grad_fn differs between the two versions:

torch==1.2.0 → l_pos.grad_fn == <AsStridedBackward object at 0x7f0ed25cc990>

torch==1.6.0 → l_pos.grad_fn == <ViewBackward object at 0x7fa1a422f2d0>

Could this kind of grad_fn change affect backpropagation through the same code?

No, I don’t think the reshape/view change is related to your issue; otherwise we would have seen many more reports since the change.


I found the issue!

I was starting from a pretrained model, so I was loading both the model weights and the saved Adam optimizer state.

I don’t know why, but when I removed that code and did not load the saved state into the Adam optimizer, training worked fine.
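The thread does not establish why the loaded optimizer state hurt training. One mechanism that can make it matter: Adam keeps per-parameter state (step, exp_avg, exp_avg_sq, as seen in the diff above), and loading a saved optimizer state restores those buffers, which then drive the next updates directly. The pure-Python sketch below uses the v1.6.0 update order and compares a fresh state with a hypothetical stale one (a large step count and a first moment pointing the opposite way); the stale values are made up for illustration only.

```python
import math


def adam_update(grad, m, v, step, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return (delta, m, v, step) for one scalar Adam step in the v1.6.0 order."""
    step += 1
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    denom = math.sqrt(v) / math.sqrt(bias_correction2) + eps
    delta = -(lr / bias_correction1) * m / denom
    return delta, m, v, step


grad = 0.5
# Fresh optimizer: empty state, as when the saved optimizer state is not loaded.
fresh_delta, *_ = adam_update(grad, m=0.0, v=0.0, step=0)
# Hypothetical loaded state: old moments whose sign disagrees with the current gradient.
stale_delta, *_ = adam_update(grad, m=-0.5, v=1e-6, step=10000)
print(f"fresh state: delta = {fresh_delta:.6f}")
print(f"stale state: delta = {stale_delta:.6f}")
```

With the fresh state the parameter moves against the gradient by roughly lr, while the stale first moment makes the step move the wrong way and about 25 times further. In PyTorch terms, the workaround described above corresponds to loading only the model's state_dict and constructing a new torch.optim.Adam instead of calling load_state_dict on the optimizer.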

Thank you for your comment.
