How to efficiently save the logits output of a model

Hi, I am trying to save the logits output of a model, but my code consumes a very large amount of GPU memory and eventually raises an OOM exception. I tried printing the memory cost of each logits tensor, and the result shows it is small, so I don't know how to debug this further. Please let me know how to efficiently save the logits.

The code is:

        net.eval()
        for batch_idx, (images, labels) in enumerate(self.local_training_data):
            images, labels = images.to(self.device), labels.to(self.device)
            log_probs, extracted_features = net(images)
            #self.extracted_feature_dict[batch_idx] = extracted_features.cpu()
            logging.info("shape = " + str(log_probs.shape))
            logging.info("element size = " + str(log_probs.element_size()))
            logging.info("nelement = " + str(log_probs.nelement()))
            logging.info("GPU memory1 = " + str(log_probs.nelement() * log_probs.element_size()))
            self.logits_dict[batch_idx] = log_probs.cpu()

The model definition is:

    def forward(self, x):

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)  # B x 16 x 32 x 32
        extracted_features = x

        x = self.layer1(x)  # B x 16 x 32 x 32

        # output here
        x = self.avgpool(x)  # B x 16 x 1 x 1
        x_f = x.view(x.size(0), -1)  # B x 16
        logits = self.fc(x_f)  # B x num_classes
        return logits, extracted_features

The log shows:

INFO:root:shape = torch.Size([64, 10])
INFO:root:element size = 4
INFO:root:nelement = 640
INFO:root:GPU memory1 = 2560

The exception is:

  File "/home/chaoyanghe/sourcecode/FederatedLearning/group_knowledge_transfer_single/models/cifar/model_client.py", line 55, in forward
    out = self.bn1(out)
  File "/home/chaoyanghe/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/chaoyanghe/anaconda3/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 107, in forward
    exponential_average_factor, self.eps)
  File "/home/chaoyanghe/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1670, in batch_norm
    training, momentum, eps, torch.backends.cudnn.enabled
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.76 GiB total capacity; 9.93 GiB already allocated; 18.19 MiB free; 9.96 GiB reserved in total by PyTorch)

Hardware environment: RTX 2080Ti
Software environment: torch 1.4.0, Python 3.7.4
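For reference, the usual pattern for collecting logits during evaluation is to wrap the loop in `torch.no_grad()` so PyTorch does not build or retain the autograd graph across batches, and to `detach()` before moving results to the CPU. A minimal self-contained sketch of that pattern (a toy `nn.Linear` model and random tensors stand in for the real `net` and data loader from the post; whether this is the actual cause of the OOM above would need to be confirmed):

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the model, data loader, and device in the post.
net = nn.Linear(8, 10)
data = [(torch.randn(4, 8), torch.zeros(4, dtype=torch.long)) for _ in range(3)]
device = torch.device("cpu")

logits_dict = {}
net.eval()
with torch.no_grad():  # no autograd graph is built, so per-batch activations can be freed
    for batch_idx, (images, labels) in enumerate(data):
        images = images.to(device)
        log_probs = net(images)
        # detach() is redundant under no_grad() but harmless; .cpu() copies the
        # small logits tensor off the GPU so only CPU memory grows per batch
        logits_dict[batch_idx] = log_probs.detach().cpu()
```

Without `no_grad()`, each stored tensor can keep a reference chain back to GPU activations, which is one common way a loop like this accumulates GPU memory even though the logits themselves are tiny.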
