Is there a way to view which activations are stored by a model during the forward pass? I’d like to figure out how to optimize the memory consumption of my model.
Here is a simple example adapted from the repo:
import torchlens as tl
import torch.nn as nn
import torch
class SimpleRecurrent(nn.Module):
    """Toy recurrent-style module that reuses one Linear layer four times.

    Each iteration applies the shared 5->5 linear layer, adds 1, then
    doubles the result, so the same weight matrix participates in every
    pass of the loop.
    """

    def __init__(self):
        super().__init__()
        # One linear layer shared across all loop iterations.
        self.fc = nn.Linear(in_features=5, out_features=5)

    def forward(self, x):
        # Four passes through the shared layer with a fixed affine
        # post-transform per pass: x <- (fc(x) + 1) * 2.
        for _ in range(4):
            x = (self.fc(x) + 1) * 2
        return x
# Random input batch: 6 samples, 5 features (matches the Linear's in_features=5).
x = torch.rand(6, 5)
simple_recurrent = SimpleRecurrent()
# Log the forward pass with torchlens. Presumably layers_to_save='all'
# retains the tensor contents of every intermediate activation, and
# vis_opt='rolled' collapses repeated passes of the shared layer in the
# visualization — confirm against the torchlens docs.
model_history = tl.log_forward_pass(simple_recurrent, x,
layers_to_save='all',
vis_opt='rolled')
# NOTE(review): the ':2' suffix appears to select the second pass of the
# (single, shared) linear layer — verify the label format with torchlens.
print(model_history['linear_1_1:2'].tensor_contents) # second pass of first linear layer
# Example output below; exact values depend on the random weight init.
'''
tensor([[-0.0690, -1.3957, -0.3231, -0.1980, 0.7197],
[-0.1083, -1.5051, -0.2570, -0.2024, 0.8248],
[ 0.1031, -1.4315, -0.5999, -0.4017, 0.7580],
[-0.0396, -1.3813, -0.3523, -0.2008, 0.6654],
[ 0.0980, -1.4073, -0.5934, -0.3866, 0.7371],
[-0.1106, -1.2909, -0.3393, -0.2439, 0.7345]])
'''