Pytorch Memory Leak -> Feature extraction with forward hook

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
#load trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.load('resnet_final.pt').to(device)
# Exclude subgraphs for feature extraction
for param in model.parameters():
    param.requires_grad = False
def extract_sizes(x, model):
    input_features = []
    output_features = []

    def hook(module, input, output):
        input_features.append(input[0].view(input[0].shape[0], -1))
        output_features.append(output)

    # Set the model to evaluation mode
    model.eval()

    # Register a forward hook on each module
    for i, module in enumerate(model.modules()):
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            module.register_forward_hook(hook)

    # Use the model on a single input tensor
    model(x)

    return input_features, output_features

# Extract features from an input tensor
x = torch.randn(1, 3, 32, 32).to(device)
input_features, output_features = extract_sizes(x, model)

# Print the shape of the features at each layer
for i, f in enumerate(input_features):
    print(f"Input: {f.shape} | Output: {output_features[i].shape}")

def extract_features(input_batch, model):
    input_features = []
    
    def hook(module, input, output):
        input[0].detach()
        input_features.append(input[0])
    
    model.eval()
    hooks = []
    for i, module in enumerate(model.modules()):
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            hooks.append(module.register_forward_hook(hook))
    
    model(input_batch)
    # remove all hooks
    for hook in hooks:
        hook.remove()
    return input_features

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                             train=True, 
                                             transform=transforms.ToTensor(),
                                             download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100, 
                                           shuffle=True)

# Create list of probes and optimizers
probes = []
for i in range(len(input_features)):
    probes.append(nn.Linear(input_features[i].shape[1], 10).to(device))

probe_optimizers = []
for probe in probes:
    probe_optimizers.append(torch.optim.Adam(probe.parameters(), lr=0.001))

loss_fn = nn.CrossEntropyLoss()

def train_probes(input_features, probes, probe_optimizers, loss_fn):
    for i, input in enumerate(input_features):
        out = probes[i](input.view(input.shape[0], -1))
        loss = loss_fn(out, labels)
        probe_optimizers[i].zero_grad()
        loss.backward(retain_graph=True)
        probe_optimizers[i].step()
            
num_epochs = 10
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        with torch.no_grad():
            images = images.to(device)
            labels = labels.to(device)
            input_features = extract_features(images, model)
        print(torch.cuda.memory_summary(device=device, abbreviated=True))
        train_probes(input_features, probes, probe_optimizers, loss_fn)
        print(torch.cuda.memory_summary(device=device, abbreviated=True))
    print('Epoch {i} Done')

I have the above code, where I am trying to extract features from a trained Resnet18 on CIFAR10. I have noticed two issues where memory is leaking.
The first is from the function extract features, which allocates memory on each training loop rather than reuse allocated memory for the tensors where features are extracted.
The second is in more minor, but still an issue in train_probes.

I have tried detaching everything from the graph where possible, except in train_probes.

For context I am trying to implement the following paper (https://arxiv.org/pdf/1610.01644.pdf), where features extracted from every layer are trained on linear layers. The general process is

  1. extract features from layer x
  2. Train linear probe at layer x using extracted features and class label

You are not running into a memory leak but an expected increase in memory usage since you are storing tensors without detaching them. PyTorch will not be able to free the computation graph in the backward() call due to this. Additionally, you are explicitly retaining the graph via backward(retain_graph=True), so even if you fix the detaching issue the computation graph will stay alive.

tensor.detach() is not an inplace operation so you would have to either assign the output or use the inplace method via tensor.detach_().

So I finally solved this issue using the latest versions of torch with feature extraction. My code for that is below.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('resnet_final.pt').to(device)

train_nodes, eval_nodes = get_graph_node_names(model)

# filter all conv and fc layers in eval_nodes
filtered = filter(lambda x: 'conv' in x or 'fc' in x, eval_nodes)
nodes = list(filtered)

return_nodes = {v: f'layer{k}' for k, v in enumerate(nodes)}
fx = create_feature_extractor(model, return_nodes)
with torch.no_grad():
    out = fx(torch.randn(1, 3, 32, 32).to(device))

features = [t.view(1,-1) for t in out.values()]
probes = [nn.Linear(f.shape[1], 10).to(device) for f in features]
optims = [torch.optim.Adam(p.parameters(), lr=0.001) for p in probes]

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                             train=True, 
                                             transform=transforms.ToTensor(),
                                             download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100, 
                                           shuffle=True)
num_epochs = 10
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            out = fx(images)
        features = [t.view(t.shape[0], -1) for t in out.values()]
        for j, (f, p, o) in enumerate(zip(features, probes, optims)):
            out = p(f)
            loss = nn.CrossEntropyLoss()(out, labels)
            o.zero_grad()
            loss.backward()
            o.step()
            if (i+1) % 100 == 0:
                print(f'Probe [{j+1}/{len(probes)}], Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

I was still unable to get the feature extraction working in a loop using the forward hook. Not quite sure what I did wrong, as I separated everything I possibly could, printed graphs of the backward for each probe (to ensure no overlaps), but the extracted feature vectors still never properly deallocated. Non-working code is below:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot

#load trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.load('resnet_final.pt').to(device)
# Exclude subgraphs for feature extraction
for param in model.parameters():
    param.requires_grad = False

def extract_sizes(x, model):
    input_features = []
    output_features = []

    def hook(module, input, output):
        input_features.append(input[0].view(input[0].shape[0], -1))
        output_features.append(output)

    # Set the model to evaluation mode
    model.eval()

    # Register a forward hook on each module
    for i, module in enumerate(model.modules()):
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
            module.register_forward_hook(hook)

    # Use the model on a single input tensor
    model(x)

    return input_features, output_features

# Extract features from an input tensor
x = torch.randn(1, 3, 32, 32).to(device)
input_features, output_features = extract_sizes(x, model)

# Print the shape of the features at each layer
for i, f in enumerate(input_features):
    print(f"Input: {f.shape} | Output: {output_features[i].shape}")

def extract_features(input_batch, model):
    with torch.no_grad():
        input_features = []
        
        def hook(module, input, output):
            input_features.append(input[0].detach())
        
        model.eval()
        hooks = []
        for module in model.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                hooks.append(module.register_forward_hook(hook))
        with torch.no_grad():
            model(input_batch)
        # remove all hooks
        for hook in hooks:
            hook.remove()
        return input_features

# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                             train=True, 
                                             transform=transforms.ToTensor(),
                                             download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=100, 
                                           shuffle=True)

# Create list of probes and optimizers
probes = []
for i in range(len(input_features)):
    probes.append(nn.Linear(input_features[i].shape[1], 10).to(device))

probe_optimizers = []
for probe in probes:
    probe_optimizers.append(torch.optim.Adam(probe.parameters(), lr=0.001))

loss_fns = [nn.CrossEntropyLoss() for i in range(len(input_features))]

def train_probes(input_features, probes, probe_optimizers, loss_fn, labels):
    for i, input in enumerate(input_features):
        out = probes[i](input.view(input.shape[0], -1))
        loss = loss_fn[i](out, labels)
        probe_optimizers[i].zero_grad()
        loss.backward()
        probe_optimizers[i].step()
            
num_epochs = 10
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        with torch.no_grad():
            images = images.to(device)
            labels = labels.to(device)
            input_features = extract_features(images, model)
        print(torch.cuda.memory_summary(device=device, abbreviated=True))
        train_probes(input_features, probes, probe_optimizers, loss_fns, labels)
        print(torch.cuda.memory_summary(device=device, abbreviated=True))
    print('Epoch {i} Done')