Guidance on why my model uses a huge amount of memory & time!

Dear Community,

My Model-Agnostic Meta-Learning (MAML) model (on the CIFAR dataset, with inputs of shape [N, 3, 32, 32] and non-one-hot labels) is taking an unreasonable :sob: amount of resources to run 10000 iterations (24 hrs and 40 GB!), and I can’t put my finger on what’s wrong. Could you please give me some pointers on what to check in my debugging efforts?

I’m also wondering if there’s a better way to tell that a model is misusing resources than running a ton of iterations only for it to fail – perhaps some calculation to estimate the ‘expected & reasonable’ resource use beforehand as a baseline?
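(What I have in mind as a baseline is something like the sketch below – it just records peak GPU memory over a handful of iterations with torch.cuda.max_memory_allocated(). Here run_one_iteration is a hypothetical stand-in for one full outer step, not something from my actual code. If the peak keeps growing iteration after iteration, presumably something is being retained that shouldn’t be.)

import torch

def measure_peak_memory(run_one_iteration, num_iterations=5):
    # Reset the allocator's peak counter, run a few iterations, and report
    # the highest GPU memory allocation seen by the caching allocator.
    torch.cuda.reset_peak_memory_stats()
    for i in range(num_iterations):
        run_one_iteration()
        peak_gib = torch.cuda.max_memory_allocated() / 1024**3
        print(f'peak GPU memory after iteration {i}: {peak_gib:.2f} GiB')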

I’ve read the documentation here Performance Tuning Guide — PyTorch Tutorials 2.1.0+cu121 documentation, and I think I might have messed up something simpler – I’ve trained models with more parameters (ones that don’t call backward() twice) on the same GPU, and they run much faster.

My code is below. (All I did was take out the save() and write() methods and instead return all the losses as a very long list so I can plot them – but that wouldn’t take 40 GB, right? :thinking:)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch.utils.data import dataloader

import util  # local helper module providing util.score()

LOG_INTERVAL = 10
VAL_INTERVAL = LOG_INTERVAL * 5
NUM_TEST_TASKS = 600

class Network_with_second_back_propagation():
    """
    Builds and runs a network that supports second-order back-propagation
    -
    Functions:
        1 ```build_network()``` : initializes parameters\n
            the parameters are maintained (updated) OUTSIDE this object!
        2 ```forward_propagation()``` : runs incoming data through the given parameters\n"""
    def __init__(self, device : str ='cuda' if torch.cuda.is_available() else 'cpu'):
        self._device = device
    def build_network(self,) -> dict[str, torch.Tensor]:
        """
        Makes dictionary of parameters for manual calculation
            This dictionary is maintained OUTSIDE this object
        Returns: 
            1 ```parameter``` (dict[str, torch.Tensor]): the network's parameter"""
        raise NotImplementedError()
    def forward_propagation(self, incoming : torch.Tensor, parameter : dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Inputs:
            1 ```incoming``` (torch.Tensor): input to the network 
                (named to avoid a collision with Python's built-in 'input').
                shape (Batch, input size)
            2 ```parameter``` (dict[str, torch.Tensor]): parameter to calculate ```incoming```
                follows SAME form as build_network()'s return
        Return:
            1 ```result``` (torch.Tensor) calculated by the network 
                shape (Batch, result size)"""
        raise NotImplementedError()
    
class Network_convolution_then_linear(Network_with_second_back_propagation):
    def __init__(self, output_dimension : int, number_of_convolution_layer : int,
            number_of_input_channel : int, number_of_hidden_channel : int, convolution_kernel_size : int,
            ):
        super().__init__()
        self._output_dimension = output_dimension
        self._number_of_convolution_layer = number_of_convolution_layer
        self._number_of_input_channel = number_of_input_channel
        self._number_of_hidden_channel = number_of_hidden_channel
        self._convolution_kernel_size = convolution_kernel_size
    def build_network(self,) -> dict[str, torch.Tensor]:
        parameter : dict[str, torch.Tensor] = {}
        # build convolutional layer (feature extractor)
        in_channels = self._number_of_input_channel
        for layer in range(self._number_of_convolution_layer):
            parameter[f'convolution_weight{layer}'] = nn.init.xavier_uniform_(
                torch.empty(
                    self._number_of_hidden_channel,
                    in_channels,
                    self._convolution_kernel_size,
                    self._convolution_kernel_size,
                    requires_grad=True,
                    device=self._device))
            parameter[f'convolution_bias{layer}'] = nn.init.zeros_(
                torch.empty(
                    self._number_of_hidden_channel,
                    requires_grad=True,
                    device=self._device))
            in_channels = self._number_of_hidden_channel
        # build linear layer (linear head)
        parameter['linear_weight'] = nn.init.xavier_uniform_(
            torch.empty(
                self._output_dimension,
                self._number_of_hidden_channel,
                requires_grad=True,
                device=self._device))
        parameter['linear_bias'] = nn.init.zeros_(
            torch.empty(
                self._output_dimension,
                requires_grad=True,
                device=self._device))
        return parameter
    def forward_propagation(self, incoming, parameter):
        x = incoming
        for layer in range(self._number_of_convolution_layer):
            x = F.conv2d(
                input=x,
                weight=parameter[f'convolution_weight{layer}'],
                bias=parameter[f'convolution_bias{layer}'],
                stride=1,
                padding='same'
            )
            x = F.batch_norm(x, None, None, training=True)  # no running stats tracked: batch statistics are recomputed on every call
            x = F.relu(x)
        x = torch.mean(x, dim=[2, 3])
        x = F.linear(
            input=x,
            weight=parameter['linear_weight'],
            bias=parameter['linear_bias'])
        return x

class MAML:
    """Model-Agnostic Meta-Learning network object
    -
    Input:
        1 ```number of inner steps``` (int): steps in inner loop\n
        2 ```inner learning rate``` (float): lr for inner loop\n
        3 ```learn inner learning rate``` (bool): learn inner lr\n
        4 ```outer learning rate``` (float): lr for outer loop\n
        5 ```network``` (Network_with_second_back_propagation): network to build model 
            and forward-propagate
    """
    def __init__(
            self,  
            number_of_inner_steps : int, 
            inner_learning_rate : float, learn_inner_learning_rate : bool, 
            outer_learning_rate : float, 
            network : Network_with_second_back_propagation
    ):
        self._network : Network_with_second_back_propagation = network
        self._meta_parameters = self._network.build_network()
        self._number_of_inner_steps = number_of_inner_steps
        self._inner_learning_rate : dict[str, torch.Tensor]= {
            k: torch.tensor(inner_learning_rate, requires_grad=learn_inner_learning_rate)
            for k in self._meta_parameters.keys()}
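        # (one learnable inner-loop step size per parameter tensor; when
        #  learn_inner_learning_rate=True these are updated by the outer Adam optimizer below)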
        self._outer_learning_rate = outer_learning_rate
        self._optimizer = torch.optim.Adam(
            list(self._meta_parameters.values()) +
            list(self._inner_learning_rate.values()),
            lr=self._outer_learning_rate)
        self._start_train_step = 0

    def _inner_loop(self, support_input : torch.Tensor, support_label : torch.Tensor, we_are_training : bool) -> tuple[dict[str, torch.Tensor], list[float]]:
        """Adapts network parameters to ONE task
        -
        Inputs:
            1 ```support_input``` (Tensor): task support set inputs
                shape (number of images, channels, height, width)
            2 ```support_label``` (Tensor): task support set labels
                shape (number of images,)
            3 ```we_are_training``` (bool): whether we are training or evaluating
                received from ```_outer_step()```
        Returns:
            1 ```parameter``` (dict[str, Tensor]): adapted network parameters.\n
            2 ```accuracy_list``` (list[float]): support set accuracy over the course of
                the inner loop, length number_of_inner_steps + 1
        """
        accuracy_list = []
        cloned_parameter = {
            k: torch.clone(v)
            for k, v in self._meta_parameters.items()
        }
        # This method runs the inner-loop (adaptation) procedure for
        # _number_of_inner_steps steps on one task and scores the
        # model along the way.
        # Use F.cross_entropy to compute classification losses.
        # Use util.score to compute accuracies.
        for _ in range(self._number_of_inner_steps):
            # Forward propagation to obtain y_support
            y_support = self._network.forward_propagation(incoming=support_input, parameter=cloned_parameter)
            # get loss
            inner_loss = F.cross_entropy(y_support, support_label)
            # get accuracy
            accuracy = util.score(y_support, support_label)
            accuracy_list.append(accuracy)
            # get gradients (for EACH layer)
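            # create_graph=we_are_training keeps this gradient computation in the autograd
            # graph during training, so the outer loss can differentiate through the
            # adaptation (second-order MAML); this is the main memory cost of training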
            grads_list = autograd.grad(inner_loss, cloned_parameter.values(), create_graph=we_are_training)
            for (layer, layer_name) in enumerate(cloned_parameter.keys()):
                cloned_parameter[layer_name] = cloned_parameter[layer_name] - self._inner_learning_rate[layer_name] * grads_list[layer]
        # post-adaptation accuracy <- note that we're NOT concerned with the loss!
        y_support = self._network.forward_propagation(incoming=support_input, parameter=cloned_parameter)
        post_adapt_accuracy = util.score(y_support, support_label)
        accuracy_list.append(post_adapt_accuracy)
        return cloned_parameter, accuracy_list

    def _outer_step(self, task_batch : tuple[tuple[torch.Tensor, ...], ...], we_are_training : bool):
        """Get Loss from BATCH of Tasks.
        -
        Inputs:
            1 ```task batch``` (tuple): batch of tasks from a task DataLoader
                each task is a (images support, labels support, images query, labels query) tuple
            2 ```we are training``` (bool): whether we are training or evaluating
        Returns:
            1 ```outer loss``` (Tensor): mean MAML loss over the batch, 
                Scalar.
            2 ```accuracies support``` (ndarray): support set accuracy over inner loop steps, 
                averaged over the task batch dimension.
                shape (number of inner steps + 1,)
            3 ```accuracy query``` (float): query set accuracy of the adapted
                parameters, averaged over the task batch
        """
        outer_loss_batch_list = []
        accuracies_support_batch_list = []
        accuracy_query_batch_list = []
        for task in task_batch:
            images_support, labels_support, images_query, labels_query = task
            images_support = images_support.to(self._network._device)
            labels_support = labels_support.to(self._network._device)
            images_query = images_query.to(self._network._device)
            labels_query = labels_query.to(self._network._device)
            # For a given task, use the _inner_loop method to adapt for
            # _num_inner_steps steps, then compute the MAML loss and other metrics.
            # Use F.cross_entropy to compute classification losses.
            # Use util.score to compute accuracies.
            adapted_parameters, accuracy_support_over_inner_steps = self._inner_loop(support_input=images_support, support_label=labels_support, we_are_training=we_are_training)
            predicted_labels_query = self._network.forward_propagation(incoming=images_query, parameter=adapted_parameters)
            # get loss
            outer_loss_batch = F.cross_entropy(predicted_labels_query, labels_query)
            outer_loss_batch_list.append(outer_loss_batch) # NOT .item() because we need to back propagate!
            # get accuracy
            accuracies_support_batch_list.append(accuracy_support_over_inner_steps)
            accuracy_query_batch = util.score(predicted_labels_query, labels_query)
            accuracy_query_batch_list.append(accuracy_query_batch)
        outer_loss = torch.mean(torch.stack(outer_loss_batch_list))
        accuracies_support = np.mean(accuracies_support_batch_list, axis=0)  # no need to back-propagate through accuracies
        accuracy_query = np.mean(accuracy_query_batch_list)
        return outer_loss, accuracies_support, accuracy_query

    def train(self, dataloader_train : dataloader.DataLoader, 
              dataloader_val : dataloader.DataLoader) -> tuple[list[float], ...]:
        """Train the MAML.
        -
        1. Optimizes the MAML parameters with ```dataloader train```
        2. Periodically validates on ```dataloader val``` and logs metrics
        (checkpoint saving was removed; losses are returned as a list instead).

        Inputs:
            1 ```dataloader train``` (DataLoader): loader for train tasks\n
            2 ```dataloader val``` (DataLoader): loader for validation tasks"""
        print(f'Starting training at iteration {self._start_train_step}.')
        loss_list = []
        accuracy_list = []
        for i_step, task_batch in enumerate(dataloader_train, start=self._start_train_step):
            self._optimizer.zero_grad()
            # 1. Optimize MAML parameter with dataloader_train
            outer_loss, accuracies_support, accuracy_query = self._outer_step(task_batch, we_are_training=True)
            outer_loss.backward()
            self._optimizer.step()
            # write down loss and accuracy
            loss_list.append(outer_loss.item())
            accuracy_list.append(accuracy_query)
            # 2. Periodically validate on dataloader_val
            if i_step % LOG_INTERVAL == 0:
                print(
                    f'Iteration {i_step}: '
                    f'loss: {outer_loss.item():.3f}, '
                    f'pre-adaptation support accuracy: '
                    f'{accuracies_support[0]:.3f}, '
                    f'post-adaptation support accuracy: '
                    f'{accuracies_support[-1]:.3f}, '
                    f'post-adaptation query accuracy: '
                    f'{accuracy_query:.3f}')
            if i_step % VAL_INTERVAL == 0:
                losses = []
                accuracies_pre_adapt_support = []
                accuracies_post_adapt_support = []
                accuracies_post_adapt_query = []
                for val_task_batch in dataloader_val:
                    outer_loss, accuracies_support, accuracy_query = self._outer_step(val_task_batch, we_are_training=False)
                    losses.append(outer_loss.item())
                    accuracies_pre_adapt_support.append(accuracies_support[0])
                    accuracies_post_adapt_support.append(accuracies_support[-1])
                    accuracies_post_adapt_query.append(accuracy_query)
                loss = np.mean(losses)
                accuracy_pre_adapt_support = np.mean(accuracies_pre_adapt_support)
                accuracy_post_adapt_support = np.mean(accuracies_post_adapt_support)
                accuracy_post_adapt_query = np.mean(accuracies_post_adapt_query)
                print(
                    f'Validation: '
                    f'loss: {loss:.3f}, '
                    f'pre-adaptation support accuracy: '
                    f'{accuracy_pre_adapt_support:.3f}, '
                    f'post-adaptation support accuracy: '
                    f'{accuracy_post_adapt_support:.3f}, '
                    f'post-adaptation query accuracy: '
                    f'{accuracy_post_adapt_query:.3f}')
        return loss_list, accuracy_list

    def test(self, dataloader_test : dataloader.DataLoader) -> list[float]:
        """Evaluate the MAML on test tasks.
        -
        Inputs:
            1 ```dataloader test``` (DataLoader): loader for test tasks
        """
        accuracy_list = []
        for task_batch in dataloader_test:
            _, _, accuracy_query = self._outer_step(task_batch, we_are_training=False)
            accuracy_list.append(accuracy_query)
        mean = np.mean(accuracy_list)
        std = np.std(accuracy_list)
        mean_95_confidence_interval = 1.96 * std / np.sqrt(NUM_TEST_TASKS)
        print(
            f'Accuracy over {NUM_TEST_TASKS} test tasks: '
            f'mean {mean:.3f}, '
            f'95% confidence interval {mean_95_confidence_interval:.3f}')
        return accuracy_list

Based on the code, you are appending a few tensors to lists. Did you check whether any of these tensors are still attached to a computation graph? Appending such a tensor stores not only the tensor itself in the list but also its entire computation graph, preventing PyTorch from freeing it.
If you don’t need to call backward() on any of the list items and only want to store them for logging etc., call .detach() on the tensor before appending it.
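For example (a minimal, self-contained sketch of the pattern, not your actual training loop):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 10), torch.randn(32, 1)

losses = []
for _ in range(100):
    loss = F.mse_loss(model(x), y)  # still attached to the computation graph
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.detach())    # or loss.item(); appending `loss` directly would
                                    # keep every iteration's graph alive in the list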


Thank you so much ptrblck! I had no idea appending these tensors would keep the entire computation graph alive like that! :exploding_head: I will fix this and see how the model does!
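In case it helps anyone else later: my plan is to double-check everything appended inside _inner_loop / _outer_step and detach anything that might still be a graph-attached tensor, roughly along these lines (a sketch with a hypothetical helper, assuming util.score() returns a tensor; if it already returns a plain float, those appends are fine):

def to_python_number(value):
    # hypothetical helper: detach graph-attached tensors and convert to a float before logging
    if torch.is_tensor(value):
        return value.detach().cpu().item()
    return value

# e.g. inside _inner_loop:
#     accuracy_list.append(to_python_number(util.score(y_support, support_label)))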