Mixed precision training requires more memory than normal training

I am using mixed precision training with a GradScaler and autocast(). My main reason for using fp16 is to fit a bigger model onto my GPU.

However, it seems that mixed precision training requires more GPU memory than training without it: my code now runs out of GPU memory on the backward call, where it could train fine before…

scaler.scale(loss).backward()
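
For context, my training step follows the standard recipe from the AMP docs, roughly like this (a minimal sketch; model, optimizer, loss_fn, and the loader variables stand in for my actual objects):

    import torch
    from torch.cuda.amp import autocast, GradScaler

    scaler = GradScaler()

    for template, search, target in loader:
        optimizer.zero_grad()
        # the forward pass runs under autocast, so eligible ops use fp16
        with autocast():
            output = model(template, search)
            loss = loss_fn(output, target)
        # scale the loss so small fp16 gradients don't underflow to zero
        scaler.scale(loss).backward()
        # unscales the gradients and skips the step if they contain inf/NaN
        scaler.step(optimizer)
        scaler.update()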

I am assuming this is caused by copies of the fp32 weights being made when converting to fp16 in the backward call.

I followed https://pytorch.org/docs/stable/amp.html and the other available resources but haven’t been able to find a solution. Ideally I would like to keep the copied fp32 weights on the CPU, especially if I am not using them at that point.

If I lower my (already small) batch size from 4 to 2, it does work; however, I wanted to use mixed precision training to increase my batch size/model size!

Any help would be much appreciated!

The native mixed-precision implementation doesn’t use master parameters and thus will not create copies.
What kind of model are you using and could you post the model definition with all arguments and random input shapes here, please?

Thanks for the help!

The native mixed-precision implementation doesn’t use master parameters and thus will not create copies.
Okay… I assume that using mixed precision shouldn’t increase my GPU memory usage, right? At best it should decrease it?
I haven’t been able to find any indication of what memory improvements/changes one can expect.

I am using a Siamese tracker with a ResNet-50 backbone in combination with a co-attention transformer from mcan-vqa.

Strangely enough, when I tried it today it did work with the same batch size and mixed precision training. Maybe I still had some allocated GPU memory…
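
In case it helps anyone else, one way to check for leftover allocations before starting a run (a small sketch, nothing specific to my setup):

    import torch

    # what this process currently holds on the GPU
    print(f"allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
    print(f"reserved:  {torch.cuda.memory_reserved() / 1024**2:.1f} MB")

    # release cached blocks back to the driver (does not free live tensors)
    torch.cuda.empty_cache()

    # the full breakdown, same as the summary pasted further down
    print(torch.cuda.memory_summary())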

The model is a bit complicated, so I don’t know how much use this will be:
The input shapes are the following, plus a sentence that is converted using word embeddings:

name=template   shape=torch.Size([4, 3, 127, 127])
name=search   shape=torch.Size([4, 3, 255, 255])

Below is the printed model, plus a memory summary at the end from torch.cuda.memory_summary():

DistModule(
  (module): ModelBuilder(
    (backbone): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer2): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer3): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (3): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (4): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (5): Bottleneck(
          (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (layer4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (downsample): Sequential(
            (0): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2), dilation=(2, 2), bias=False)
            (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): Bottleneck(
          (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(4, 4), dilation=(4, 4), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (neck): AdjustAllLayer(
      (downsample2): AdjustLayer(
        (downsample): Sequential(
          (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (downsample3): AdjustLayer(
        (downsample): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (downsample4): AdjustLayer(
        (downsample): Sequential(
          (0): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (rpn_head): MultiRPN(
      (rpn2): DepthwiseRPN(
        (cls): DepthwiseXCorr(
          (conv_kernel): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv_search): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (head): Sequential(
            (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(256, 10, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (loc): DepthwiseXCorr(
          (conv_kernel): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv_search): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (head): Sequential(
            (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(256, 20, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
      (rpn3): DepthwiseRPN(
        (cls): DepthwiseXCorr(
          (conv_kernel): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv_search): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (head): Sequential(
            (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(256, 10, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (loc): DepthwiseXCorr(
          (conv_kernel): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv_search): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (head): Sequential(
            (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(256, 20, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
      (rpn4): DepthwiseRPN(
        (cls): DepthwiseXCorr(
          (conv_kernel): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv_search): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (head): Sequential(
            (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(256, 10, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (loc): DepthwiseXCorr(
          (conv_kernel): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv_search): Sequential(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (head): Sequential(
            (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(256, 20, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
    (nlp_embedding): Embedding(
      (w2v): Word2VecEmbedding()
      (conv1d): Conv1d(300, 300, kernel_size=(2,), stride=(1,))
      (activation): ReLU()
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (constraint_classifier): LinearLayer(
      (head): Sequential(
        (0): Flatten()
        (1): Linear(in_features=1024, out_features=1, bias=True)
      )
      (activation): Sigmoid()
    )
    (constraint_loss_fn): BCEWithLogitsLoss()
    ....
)

The rest of the model (it didn’t fit in my last reply), and the memory summary:

    (deep_attention): MCA_ED(
      (enc_list): ModuleList(
        (0): SA(
          (mhatt): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
        )
        (1): SA(
          (mhatt): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
        )
        (2): SA(
          (mhatt): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
        )
        (3): SA(
          (mhatt): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
        )
        (4): SA(
          (mhatt): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
        )
        (5): SA(
          (mhatt): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
        )
      )
      (dec_list): ModuleList(
        (0): SGA(
          (mhatt1): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (mhatt2): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
          (dropout3): Dropout(p=0.1, inplace=False)
          (norm3): LayerNorm()
        )
        (1): SGA(
          (mhatt1): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (mhatt2): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
          (dropout3): Dropout(p=0.1, inplace=False)
          (norm3): LayerNorm()
        )
        (2): SGA(
          (mhatt1): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (mhatt2): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
          (dropout3): Dropout(p=0.1, inplace=False)
          (norm3): LayerNorm()
        )
        (3): SGA(
          (mhatt1): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (mhatt2): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
          (dropout3): Dropout(p=0.1, inplace=False)
          (norm3): LayerNorm()
        )
        (4): SGA(
          (mhatt1): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (mhatt2): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
          (dropout3): Dropout(p=0.1, inplace=False)
          (norm3): LayerNorm()
        )
        (5): SGA(
          (mhatt1): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (mhatt2): MHAtt(
            (linear_v): Linear(in_features=256, out_features=256, bias=True)
            (linear_k): Linear(in_features=256, out_features=256, bias=True)
            (linear_q): Linear(in_features=256, out_features=256, bias=True)
            (linear_merge): Linear(in_features=256, out_features=256, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (ffn): FFN(
            (mlp): MLP(
              (fc): FC(
                (linear): Linear(in_features=256, out_features=512, bias=True)
                (relu): ReLU(inplace=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (linear): Linear(in_features=512, out_features=256, bias=True)
            )
          )
          (dropout1): Dropout(p=0.1, inplace=False)
          (norm1): LayerNorm()
          (dropout2): Dropout(p=0.1, inplace=False)
          (norm2): LayerNorm()
          (dropout3): Dropout(p=0.1, inplace=False)
          (norm3): LayerNorm()
        )
      )
    )
    (embedding_expansion): Linear(in_features=150, out_features=256, bias=True)
    (embedding_reduction): Linear(in_features=768, out_features=256, bias=True)
    (x_reshape): Sequential(
      (0): Flatten()
      (1): Linear(in_features=2560, out_features=1024, bias=True)
    )
    (y_reshape): Sequential(
      (0): Flatten()
      (1): Linear(in_features=246016, out_features=1024, bias=True)
    )
  )
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |    3251 MB |    5549 MB |  344050 MB |  340799 MB |
|       from large pool |    3134 MB |    5387 MB |  333656 MB |  330522 MB |
|       from small pool |     117 MB |     167 MB |   10394 MB |   10277 MB |
|---------------------------------------------------------------------------|
| Active memory         |    3251 MB |    5549 MB |  344050 MB |  340799 MB |
|       from large pool |    3134 MB |    5387 MB |  333656 MB |  330522 MB |
|       from small pool |     117 MB |     167 MB |   10394 MB |   10277 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    6350 MB |    6350 MB |    6350 MB |       0 B  |
|       from large pool |    6182 MB |    6182 MB |    6182 MB |       0 B  |
|       from small pool |     168 MB |     168 MB |     168 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |     960 MB |    1056 MB |  259177 MB |  258216 MB |
|       from large pool |     941 MB |    1035 MB |  248215 MB |  247273 MB |
|       from small pool |      18 MB |      30 MB |   10961 MB |   10942 MB |
|---------------------------------------------------------------------------|
| Allocations           |    1330    |    2045    |  167490    |  166160    |
|       from large pool |      50    |     267    |   49870    |   49820    |
|       from small pool |    1280    |    1779    |  117620    |  116340    |
|---------------------------------------------------------------------------|
| Active allocs         |    1330    |    2045    |  167490    |  166160    |
|       from large pool |      50    |     267    |   49870    |   49820    |
|       from small pool |    1280    |    1779    |  117620    |  116340    |
|---------------------------------------------------------------------------|
| GPU reserved segments |     107    |     107    |     107    |       0    |
|       from large pool |      23    |      23    |      23    |       0    |
|       from small pool |      84    |      84    |      84    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |     108    |     122    |   80164    |   80056    |
|       from large pool |      10    |      25    |   26731    |   26721    |
|       from small pool |      98    |     111    |   53433    |   53335    |
|===========================================================================|

Yes, it shouldn’t use more memory, so maybe some processes were still using some device memory.
Anyway, let us know if you encounter any other issues.

I have some additional questions that I did not yet find the answers to:

Should I edit my model so that the batchnorm layers run in FP32 mode, or is this done automatically? (As I understand it, they might underflow/overflow otherwise.)

How does it work when using the model for testing/inference: is it normal to also run it in mixed precision mode, and could that negatively affect testing performance?

Thanks!

You shouldn’t change any layers, as this will be done automatically for you.
If you encounter any instabilities, please post the layer and ping me so that we can have a look at it.

During inference you can also use mixed precision (autocast), and it should not affect the testing performance (same as during training).
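
A minimal sketch of what that could look like (assuming a generic model and input batch, not the specific tracker above):

    import torch
    from torch.cuda.amp import autocast

    model.eval()
    with torch.no_grad():      # no gradients needed for inference
        with autocast():       # eligible ops run in fp16
            prediction = model(data)
    # outputs may be fp16; cast back if downstream code expects fp32
    prediction = prediction.float()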

Hi @ptrblck! Back again with another question… :slight_smile:
Not sure if I should’ve opened a new thread, so here we go:

I have this piece of code (an attention calculation) that I am trying to fix while using mixed precision (for the model):

    def att(self, value, key, query, mask):
        # scaled dot-product attention
        d_k = query.size(-1)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # fill masked positions with the most negative representable value
            # print(scores.shape, mask.shape, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)

        att_map = F.softmax(scores, dim=-1)
        att_map = self.dropout(att_map)

        return torch.matmul(att_map, value)

This works for me without mixed precision, but it does not with mixed precision. It seems to be related to the softmax (and the mask not being able to use sufficiently low numbers in fp16, I think): I get an overflow error and my gradients and loss become NaNs.

I tried to run this part of the model with autocast(enabled=False), also in combination with things like @custom_fwd(cast_inputs=torch.float32), but haven’t been able to get it to work. With these options I either get a cuDNN error about not finding a kernel to perform the convolutions (in other parts of the network), or I keep getting the under/overflow errors.
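
What I tried was roughly along these lines (a sketch only; the actual module has more context around it), casting the inputs back to fp32 inside a disabled-autocast region:

    import math
    import torch
    import torch.nn.functional as F
    from torch.cuda.amp import autocast

    def att(self, value, key, query, mask):
        # leave the autocast region so everything below runs in fp32
        with autocast(enabled=False):
            value, key, query = value.float(), key.float(), query.float()
            d_k = query.size(-1)
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
            if mask is not None:
                scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
            att_map = self.dropout(F.softmax(scores, dim=-1))
            return torch.matmul(att_map, value)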

torch.Size([32, 6, 75, 20]) and torch.Size([32, 1, 1, 20]) are the shapes of scores and mask, respectively.

Otherwise, mixed precision has been working well for me; when I turn it off I have to lower the batch_size because of the memory requirements.

Thanks.

Could you post or upload some values for scores and mask that produce the NaNs?
I cannot reproduce it locally using random inputs.

I uploaded the two tensors:
Scores: https://easyupload.io/6j6lvd
Masks: https://easyupload.io/vixfm9

This is the warning I am getting: WARNING:root:NaN or Inf found in input tensor.

Thanks for the help!

Thanks for the tensors.
The uploaded scores already contain NaN values in rows [9, 24, 26], so the softmax output will also contain NaN values.
Also, the mask contains False for all entries.
Could you post the input tensors and the code snippet that produces these NaNs, or are you unsure where they are created?

I fixed the issue, thanks @ptrblck.

The mask was created using the following function:
(the rest of the code is quite big…)

    def make_mask(self, feature):
        # True where the feature vector along the last dim is all zeros,
        # i.e. at padding positions
        return (torch.sum(
            torch.abs(feature),
            dim=-1
        ) == 0).unsqueeze(1).unsqueeze(2)

I now create the mask not just from the features, but from the zeros/ones padding ‘mask’ that I had already built earlier when creating the sentence embedding matrix (and then pass that through the function above).
Some of the embedding features that should have been zero were not (probably because of in-between ops), and that still messed up the softmax.
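
Sketched out with hypothetical names (the real code is spread over several files), the change is roughly:

    import torch

    # hypothetical example: a zeros/ones matrix with the same layout as the
    # sentence embedding, 1s for real tokens and 0s for padding
    sentence_pad_matrix = torch.tensor([[[1.], [1.], [0.], [0.]]])  # [B, T, 1]

    def make_mask(feature):
        # True where the vector along the last dim is all zeros (padding)
        return (torch.sum(torch.abs(feature), dim=-1) == 0).unsqueeze(1).unsqueeze(2)

    # before: mask = make_mask(embedded_sentence)   # features, not exactly zero
    # after:
    mask = make_mask(sentence_pad_matrix)           # -> [B, 1, 1, T], True at padding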
Hope that is enough info :slight_smile: .