AMP convergence issues?

Hi all,

I seem to be experiencing convergence issues when AMP is activated. Here are the logs:

Without AMP

training examples: 2781
image instances over all epochs: 25000
average utilisation per image over all epochs: 8.99

epoch 0/200: tloss:1.601 tacc:0.216 vloss:1.575 vacc:0.336
epoch 10/200: tloss:1.269 tacc:0.768 vloss:1.377 vacc:0.600
epoch 20/200: tloss:1.220 tacc:0.752 vloss:1.318 vacc:0.664
epoch 30/200: tloss:1.146 tacc:0.840 vloss:1.251 vacc:0.712
epoch 40/200: tloss:1.148 tacc:0.784 vloss:1.274 vacc:0.656
epoch 50/200: tloss:1.100 tacc:0.912 vloss:1.242 vacc:0.792
epoch 60/200: tloss:1.095 tacc:0.864 vloss:1.253 vacc:0.776
epoch 70/200: tloss:1.069 tacc:0.904 vloss:1.167 vacc:0.880
epoch 80/200: tloss:1.062 tacc:0.904 vloss:1.189 vacc:0.792
epoch 90/200: tloss:1.042 tacc:0.936 vloss:1.164 vacc:0.888
epoch 100/200: tloss:1.023 tacc:0.936 vloss:1.148 vacc:0.896
epoch 110/200: tloss:1.038 tacc:0.952 vloss:1.226 vacc:0.776
epoch 120/200: tloss:0.984 tacc:0.992 vloss:1.130 vacc:0.856
epoch 130/200: tloss:0.981 tacc:0.984 vloss:1.134 vacc:0.904
epoch 140/200: tloss:0.991 tacc:0.976 vloss:1.134 vacc:0.864
epoch 150/200: tloss:0.975 tacc:0.976 vloss:1.083 vacc:0.928
epoch 160/200: tloss:0.973 tacc:0.976 vloss:1.082 vacc:0.936
epoch 170/200: tloss:0.969 tacc:0.992 vloss:1.111 vacc:0.920
epoch 180/200: tloss:0.974 tacc:0.968 vloss:1.058 vacc:0.928
epoch 190/200: tloss:0.968 tacc:0.976 vloss:1.106 vacc:0.888
epoch 199/200: tloss:0.986 tacc:0.960 vloss:1.129 vacc:0.840

With AMP

epoch 0/200: tloss:1.601 tacc:0.272 vloss:1.603 vacc:0.216
epoch 10/200: tloss:1.530 tacc:0.456 vloss:1.540 vacc:0.464
epoch 20/200: tloss:1.429 tacc:0.608 vloss:1.500 vacc:0.472
epoch 30/200: tloss:1.334 tacc:0.680 vloss:1.454 vacc:0.496
epoch 40/200: tloss:1.295 tacc:0.728 vloss:1.406 vacc:0.648
epoch 50/200: tloss:1.284 tacc:0.720 vloss:1.400 vacc:0.624
epoch 60/200: tloss:1.272 tacc:0.712 vloss:1.364 vacc:0.656
epoch 70/200: tloss:1.231 tacc:0.768 vloss:1.364 vacc:0.688
epoch 80/200: tloss:1.222 tacc:0.808 vloss:1.323 vacc:0.688
epoch 90/200: tloss:1.223 tacc:0.776 vloss:1.325 vacc:0.744
epoch 100/200: tloss:1.214 tacc:0.776 vloss:1.368 vacc:0.584
epoch 110/200: tloss:1.188 tacc:0.824 vloss:1.277 vacc:0.728
epoch 120/200: tloss:1.169 tacc:0.824 vloss:1.289 vacc:0.736
epoch 130/200: tloss:1.190 tacc:0.784 vloss:1.286 vacc:0.672
epoch 140/200: tloss:1.183 tacc:0.792 vloss:1.269 vacc:0.760
epoch 150/200: tloss:1.170 tacc:0.856 vloss:1.270 vacc:0.728
epoch 160/200: tloss:1.170 tacc:0.792 vloss:1.296 vacc:0.680
epoch 170/200: tloss:1.143 tacc:0.832 vloss:1.287 vacc:0.736
epoch 180/200: tloss:1.143 tacc:0.864 vloss:1.273 vacc:0.736
epoch 190/200: tloss:1.131 tacc:0.880 vloss:1.275 vacc:0.752
epoch 199/200: tloss:1.161 tacc:0.864 vloss:1.243 vacc:0.800

The network is a matching network with the following high-level flow (a rough sketch is shown after the list):

  • Embedding net: with a set of conv layers
  • LSTM: for full context embeddings
  • Similarity computation and loss calculation
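
In case it helps, here is a simplified sketch of that flow (illustrative only; the class and layer names below are my own, not the actual `matching.MatchingNetwork` implementation):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MatchingNetSketch(nn.Module):
        """Embed images with a small conv stack, contextualise the support set
        with an LSTM (full context embeddings), then classify targets by
        cosine similarity against the support embeddings."""

        def __init__(self, num_channels, emb_dim=64):
            super().__init__()
            self.embed = nn.Sequential(
                nn.Conv2d(num_channels, 64, 3, padding=1), nn.BatchNorm2d(64),
                nn.ReLU(), nn.MaxPool2d(2),
                nn.Conv2d(64, emb_dim, 3, padding=1), nn.BatchNorm2d(emb_dim),
                nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            )
            self.lstm = nn.LSTM(emb_dim, emb_dim, batch_first=True)  # FCE

        def forward(self, support, support_labels, target):
            s = self.embed(support)                     # (n_support, emb_dim)
            t = self.embed(target)                      # (n_target, emb_dim)
            s, _ = self.lstm(s.unsqueeze(0))            # full context embeddings
            s = s.squeeze(0)
            sims = F.cosine_similarity(t.unsqueeze(1), s.unsqueeze(0), dim=-1)
            attn = sims.softmax(dim=-1)                 # attention over the support set
            logits = attn @ F.one_hot(support_labels).float()
            return logits                               # cross-entropy is computed on these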

The differences in the AMP run are:
LSTM forward pass

    def forward(self, inputs):
        inputs = inputs.float() # CHANGE: cast to float32; without this the LSTM complains with CUDNN_STATUS_BAD_PARAM
        self.hidden = self.repackage_hidden(self.hidden) # detach hidden state from the previous graph
        output, self.hidden = self.lstm(inputs, self.hidden)
        return output
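
For context on that cast: if I understand the error correctly, it usually comes from a dtype mismatch between the float16 inputs autocast produces and a hidden state created in float32 outside autocast. Casting the inputs back to float32 sidesteps it, but also keeps the LSTM out of float16. A variant I have been wondering about (untested, and assuming `self.hidden` is the usual `(h, c)` tuple) would cast the hidden state to match the inputs instead:

    def forward(self, inputs):
        # Sketch only: cast the hidden state to the autocast dtype instead of
        # forcing the inputs back to float32, so the LSTM stays in the dtype
        # autocast selected.
        self.hidden = self.repackage_hidden(self.hidden)
        self.hidden = tuple(h.to(dtype=inputs.dtype) for h in self.hidden)
        output, self.hidden = self.lstm(inputs, self.hidden)
        return output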

Main training loop

Outside of the training loop

    scaler = torch.cuda.amp.GradScaler()

Inside the training loop

    with torch.cuda.amp.autocast(): # CHANGE: Added amp autocast
        match_net = matching.MatchingNetwork(
            keep_prob=constants.keep_prob,
            batch_size=args.batch_size,
            num_channels=constants.num_channels,
            learning_rate=constants.lr,
            fce=constants.fce,
            image_size=args.image_dim,
            device=device,
        )

....
        scaler.scale(c_loss).backward() # on loss backprop
        scaler.step(optimiser)
        scaler.update()
....

Any help / clarification would be super helpful!

I’m not sure if this is the reason, but PyTorch recommends keeping the backward pass outside the autocast context.

Here’s a quote from the PyTorch documentation:

autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.

In your case, it would be

    with torch.cuda.amp.autocast(): # CHANGE: Added amp autocast
        match_net = matching.MatchingNetwork(
            keep_prob=constants.keep_prob,
            batch_size=args.batch_size,
            num_channels=constants.num_channels,
            learning_rate=constants.lr,
            fce=constants.fce,
            image_size=args.image_dim,
            device=device,
        )

    scaler.scale(c_loss).backward() # on loss backprop
    scaler.step(optimiser)
    scaler.update()
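
More generally, the pattern from the AMP examples in the docs looks roughly like this (minimal sketch; `model`, `loss_fn`, `optimiser` and `data` stand in for your own objects):

    import torch

    scaler = torch.cuda.amp.GradScaler()

    for inputs, targets in data:
        optimiser.zero_grad()

        # forward pass and loss computation under autocast
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # backward pass, optimiser step and scaler update outside autocast
        scaler.scale(loss).backward()
        scaler.step(optimiser)
        scaler.update()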

Ah totally missed this!
It does seem to help, thank you!