Do I still need the output after calculating the loss? [before loss.backward()]

I am working on an attention-based model with a contrastive learning loss, and I am trying to reduce GPU memory as much as possible so I can increase the batch size. My model contains two sub-models that share some layers. Both perform similar tasks but in opposite directions: Part A maps A -> B, Part B maps B -> A. I get OOM at the loss.backward() step.
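Roughly, the structure looks like the sketch below (SubModel, shared_trunk, and all the dimensions are made up for illustration; my real model is an attention/ViT stack):

    import torch.nn as nn

    # Made-up sketch of the layout: two sub-models for A->B and B->A
    # that share a common trunk (names and sizes are invented).
    shared_trunk = nn.TransformerEncoderLayer(d_model=256, nhead=8)

    class SubModel(nn.Module):
        def __init__(self, trunk, out_dim):
            super().__init__()
            self.trunk = trunk               # the same module object in both parts
            self.head = nn.Linear(256, out_dim)

        def forward(self, x):
            return self.head(self.trunk(x))

    submodel_A = SubModel(shared_trunk, out_dim=128)   # A -> B
    submodel_B = SubModel(shared_trunk, out_dim=128)   # B -> A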

To reduce the memory usage, I:

  1. Break the training into two parts
  2. Delete unused variables and clear the cache

The workflow looks like this:

    Load model
    Enter batch training:
        Load data (x1, x2, x3, etc.)

        # A -> B
        out1, out2, out3, out4, etc = submodel_A(x1, x2, x3, etc)
        loss_AB = loss_fn(out1, out2, out3, out4)
        torch.cuda.empty_cache()
        with torch.no_grad():
            metrics_AB = met_fn(out1, out2, out3, out4)
        del out1, out2, out3, out4
        torch.cuda.empty_cache()
        loss_AB.backward()
        optimizer.step()
        del loss_AB
        torch.cuda.empty_cache()

        # B -> A
        out1, out2, out3, out4, etc = submodel_B(x1, x2, x3, etc)
        loss_BA = loss_fn(out1, out2, out3, out4)
        torch.cuda.empty_cache()
        with torch.no_grad():
            metrics_BA = met_fn(out1, out2, out3, out4)
        del out1, out2, out3, out4
        torch.cuda.empty_cache()
        loss_BA.backward()
        optimizer.step()
        del loss_BA
        torch.cuda.empty_cache()
               

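The memory numbers quoted under question 2 below are printed by a small helper along these lines (a reconstruction; I am assuming torch.cuda.mem_get_info here, which recent PyTorch versions provide):

    import torch

    def log_gpu_usage(tag, device=0):
        # Report used MB and percentage for one device (free/total are in bytes).
        free, total = torch.cuda.mem_get_info(device)
        used_mb = (total - free) / 2**20
        print(f"{tag}\nGPU usage on cuda {device}: used- {used_mb:.1f} MB, "
              f"percentage- {100 * (total - free) / total} %")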
Questions:

  1. I am not sure that deleting the outputs before loss_BA.backward() is correct. Does the graph of the loss already contain all the information it needs for backward(), or does it still need the outputs or the leaf nodes? I ran a test monitoring a specific parameter of the model, with and without the output deletion: .weight, .weight.grad, and .weight after optimizer.step() look identical in both cases. Does this mean I don't really need the outputs for backward()?
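To make that test concrete, this is a stripped-down version of what I checked (a toy model, not my real one):

    import torch

    model = torch.nn.Linear(4, 2)
    x = torch.randn(8, 4)

    out = model(x)              # the forward pass builds the autograd graph
    loss = out.pow(2).mean()    # loss keeps that graph alive via loss.grad_fn

    del out                     # drops only the Python name; the graph still
                                # holds whatever saved tensors backward() needs

    loss.backward()             # works without `out`
    print(model.weight.grad)    # populated as usual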

  2. I am also monitoring the memory usage at different points, and apparently deleting the outputs does not reduce the memory much; the loss seems to account for most of it. Is this because the graph attached to the loss is huge? Below is the memory usage:

| Stage | Used (MB) | Usage (%) |
|---|---|---|
| After setting up the model | 1357.0 | 8.40 |
| Entering new batch | 1357.0 | 8.40 |
| Getting input | 1377.0 | 8.52 |
| A->B: after ViT 0 | 7085.0 | 43.84 |
| A->B: after ViT 1 | 12271.0 | 75.93 |
| After getting the output | 12271.0 | 75.93 |
| After getting the loss | 12271.0 | 75.93 |
| After calculating the metrics | 12271.0 | 75.93 |
| After deleting the outputs | 11971.0 | 74.08 |
| After optimizer step | 14853.0 | 91.91 |
| After deleting the loss | 1415.0 | 8.76 |

Entering B->A (the B->A half then repeats the same steps).
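Relating to question 2: the loss tensor itself is only a few bytes; what seems to hold the memory is everything reachable from loss.grad_fn, i.e. the saved activations. A standalone toy probe (invented dimensions, not my model):

    import torch

    model = torch.nn.Sequential(
        *[torch.nn.Linear(1024, 1024) for _ in range(8)]
    ).cuda()
    x = torch.randn(2048, 1024, device="cuda")

    loss = model(x).pow(2).mean()   # intermediates stay alive only inside the graph

    print(loss.element_size() * loss.nelement())         # 4 bytes: the scalar itself
    print(torch.cuda.memory_allocated() / 2**20, "MB")   # includes saved activations

    loss.backward()                 # frees the saved activations (param grads get allocated)
    print(torch.cuda.memory_allocated() / 2**20, "MB")   # noticeably lower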

  3. I am using with torch.no_grad(): when calculating the metrics, after the loss is computed but before loss.backward(). Will this affect loss.backward()?
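What I checked for this one (again a standalone toy, not my met_fn) is that backward() still runs normally when the metrics are computed under no_grad() in between:

    import torch

    model = torch.nn.Linear(4, 2)
    x = torch.randn(8, 4)

    out = model(x)
    loss = out.pow(2).mean()

    with torch.no_grad():           # nothing in this block is recorded in the graph
        metric = (out.abs() > 0.5).float().mean()

    loss.backward()                 # unaffected by the no_grad block above
    print(metric.item(), model.weight.grad.norm().item())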