Is there a way to visualize which activations are stored during the forward pass?

Is there a way to see which activations a model stores during the forward pass? I’d like to use that information to reduce my model’s memory consumption.

Here is a simple example from the TorchLens repository:

import torchlens as tl
import torch.nn as nn
import torch

class SimpleRecurrent(nn.Module):
    """Toy recurrent model that reuses a single linear layer on every pass.

    The same ``fc`` module is applied on each of four iterations, so its
    weights are shared across passes and the output of one pass is fed back
    in as the input of the next.
    """

    def __init__(self):
        super().__init__()
        # One 5 -> 5 linear layer, shared by all passes in ``forward``.
        self.fc = nn.Linear(in_features=5, out_features=5)

    def forward(self, x):
        """Apply ``(fc(h) + 1) * 2`` four times, starting from ``h = x``.

        ``x`` needs a trailing dimension of 5 to match ``fc``'s input size.
        Each pass performs the linear layer, then the add, then the multiply,
        in that order.
        """
        hidden = x
        for _ in range(4):
            hidden = (self.fc(hidden) + 1) * 2
        return hidden

# Input batch: 6 rows of 5 features, matching the 5-input ``fc`` layer.
# NOTE(review): no seed is set, so both this input and the model's initial
# weights (drawn on the next line) change on every run; call
# torch.manual_seed() first if the printed values must be reproducible.
# Both draws use the global RNG, so keep their order if you want to
# reproduce a seeded run exactly.
x = torch.rand(6, 5)
simple_recurrent = SimpleRecurrent()
# Run one forward pass under TorchLens and capture its log.
# layers_to_save='all' presumably asks TorchLens to keep the activation of
# every layer; vis_opt='rolled' presumably renders the graph with repeated
# passes of the same layer folded together. Confirm both against the
# TorchLens documentation.
# NOTE(review): this shows the activations TorchLens records for inspection.
# That is not necessarily the set autograd keeps for backward. For memory
# tuning, torch.autograd.graph.saved_tensors_hooks can report what autograd
# actually saves.
model_history = tl.log_forward_pass(simple_recurrent, x,
                                    layers_to_save='all',
                                    vis_opt='rolled')
print(model_history['linear_1_1:2'].tensor_contents)  # second pass of first linear layer

# Sample output from one run. The exact values depend on the unseeded input
# and weights above, so they will differ between runs.
'''
tensor([[-0.0690, -1.3957, -0.3231, -0.1980,  0.7197],
        [-0.1083, -1.5051, -0.2570, -0.2024,  0.8248],
        [ 0.1031, -1.4315, -0.5999, -0.4017,  0.7580],
        [-0.0396, -1.3813, -0.3523, -0.2008,  0.6654],
        [ 0.0980, -1.4073, -0.5934, -0.3866,  0.7371],
        [-0.1106, -1.2909, -0.3393, -0.2439,  0.7345]])
'''