import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
#load trained model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = torch.load('resnet_final.pt').to(device)
# Exclude subgraphs for feature extraction
for param in model.parameters():
param.requires_grad = False
def extract_sizes(x, model):
input_features = []
output_features = []
def hook(module, input, output):
input_features.append(input[0].view(input[0].shape[0], -1))
output_features.append(output)
# Set the model to evaluation mode
model.eval()
# Register a forward hook on each module
for i, module in enumerate(model.modules()):
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
module.register_forward_hook(hook)
# Use the model on a single input tensor
model(x)
return input_features, output_features
# Extract features from an input tensor
x = torch.randn(1, 3, 32, 32).to(device)
input_features, output_features = extract_sizes(x, model)
# Print the shape of the features at each layer
for i, f in enumerate(input_features):
print(f"Input: {f.shape} | Output: {output_features[i].shape}")
def extract_features(input_batch, model):
input_features = []
def hook(module, input, output):
input[0].detach()
input_features.append(input[0])
model.eval()
hooks = []
for i, module in enumerate(model.modules()):
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
hooks.append(module.register_forward_hook(hook))
model(input_batch)
# remove all hooks
for hook in hooks:
hook.remove()
return input_features
# CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data/',
train=True,
transform=transforms.ToTensor(),
download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=100,
shuffle=True)
# Create list of probes and optimizers
probes = []
for i in range(len(input_features)):
probes.append(nn.Linear(input_features[i].shape[1], 10).to(device))
probe_optimizers = []
for probe in probes:
probe_optimizers.append(torch.optim.Adam(probe.parameters(), lr=0.001))
loss_fn = nn.CrossEntropyLoss()
def train_probes(input_features, probes, probe_optimizers, loss_fn):
for i, input in enumerate(input_features):
out = probes[i](input.view(input.shape[0], -1))
loss = loss_fn(out, labels)
probe_optimizers[i].zero_grad()
loss.backward(retain_graph=True)
probe_optimizers[i].step()
num_epochs = 10
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
with torch.no_grad():
images = images.to(device)
labels = labels.to(device)
input_features = extract_features(images, model)
print(torch.cuda.memory_summary(device=device, abbreviated=True))
train_probes(input_features, probes, probe_optimizers, loss_fn)
print(torch.cuda.memory_summary(device=device, abbreviated=True))
print('Epoch {i} Done')
I have the above code, where I am trying to extract features from a trained Resnet18 on CIFAR10. I have noticed two issues where memory is leaking.
The first is from the function extract features, which allocates memory on each training loop rather than reuse allocated memory for the tensors where features are extracted.
The second is in more minor, but still an issue in train_probes.
I have tried detaching everything from the graph where possible, except in train_probes.
For context I am trying to implement the following paper (https://arxiv.org/pdf/1610.01644.pdf), where features extracted from every layer are trained on linear layers. The general process is
- extract features from layer x
- Train linear probe at layer x using extracted features and class label