Here’s a script that reproduces the bug @albanD. The first train() call runs fine; the crash happens in train_new_neurons(), once ActiveGradsHook instances are registered on the parameters and backward() runs under torch.autograd.detect_anomaly() (note the commented-out __name__ line in ActiveGradsHook.__init__, which works around it).
import os
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, RandomSampler
# You should be able to ignore this
class SimpleModel(nn.Module):
    def __init__(self, sizes, oldWeights=None, oldBiases=None):
        super(SimpleModel, self).__init__()
        self.phase = 'ACTION'
        self.sizes = sizes
        # Encoder
        encoder_layers = self.set_module('encoder', oldWeights=oldWeights, oldBiases=oldBiases)
        encoder_layers.append(nn.Sigmoid())
        self.encoder = nn.Sequential(*encoder_layers)
        # Make float
        self.float()

    def forward(self, x):
        x = self.encoder(x)
        return x
    def set_module(self, label, oldWeights=None, oldBiases=None):
        sizes = self.sizes[label]
        if oldWeights:
            oldWeights = oldWeights[label]
        if oldBiases:
            oldBiases = oldBiases[label]
        layers = [
            self.get_layer(sizes[0], sizes[1], oldWeights, oldBiases, 0)
        ]
        for i in range(1, len(sizes) - 1):
            layers.append(nn.LeakyReLU())
            layers.append(self.get_layer(sizes[i], sizes[i + 1], oldWeights, oldBiases, i))
        return layers
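    # For the sizes used below ({"encoder": [784, 312, 128, 10]}) this returns
    # [Linear(784, 312), LeakyReLU, Linear(312, 128), LeakyReLU, Linear(128, 10)],
    # so after the trailing Sigmoid the Linear modules sit at Sequential
    # indices 0, 2, 4 -- which the index arithmetic in train_new_neurons relies on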
    def get_layer(self, input, output, init_weights=None, init_biases=None, index=0):
        layer = nn.Linear(input, output)
        torch.nn.init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='leaky_relu')
        if init_weights is not None:
            weights = init_weights[index]
            # Type checking
            if isinstance(weights, list):
                weights = np.asarray(weights, dtype=float)
            if isinstance(weights, np.ndarray):
                weights = torch.from_numpy(weights)
            if isinstance(weights, torch.Tensor):
                weights = nn.Parameter(weights)
            if isinstance(weights, torch.nn.Parameter):
                layer.weight = weights
            # Padding
            weights = layer.weight.detach()
            if input != weights.shape[1]:
                kaiming_weights = torch.rand(weights.shape[0], input - weights.shape[1]).to(weights.device)
                torch.nn.init.kaiming_uniform_(kaiming_weights, mode='fan_in', nonlinearity='leaky_relu')
                weights = torch.cat([weights.float(), kaiming_weights.float()], dim=1)
            if output != weights.shape[0]:
                kaiming_weights = torch.rand(output - weights.shape[0], input).to(weights.device)
                torch.nn.init.kaiming_uniform_(kaiming_weights, mode='fan_in', nonlinearity='leaky_relu')
                weights = torch.cat([weights.float(), kaiming_weights.float()], dim=0)
            # Set
            layer.weight = nn.Parameter(weights)
        if init_biases is not None:
            biases = init_biases[index]
            # Type checking
            if isinstance(biases, list):
                biases = np.asarray(biases, dtype=float)
            if isinstance(biases, np.ndarray):
                biases = torch.from_numpy(biases)
            if isinstance(biases, torch.Tensor):
                biases = nn.Parameter(biases)
            if isinstance(biases, torch.nn.Parameter):
                layer.bias = biases
            # Padding
            biases = layer.bias.detach()
            if output != biases.shape[0]:
                rand_biases = torch.rand(output - biases.shape[0]).to(biases.device)
                biases = torch.cat([biases.float(), rand_biases.float()], dim=0)
            # Set
            layer.bias = nn.Parameter(biases)
            # Update the oldBiases to include padding
            init_biases[index] = layer.bias.detach()
        return layer
def get_modules(model):
    """Group named parameters by their top-level module name."""
    modules = {}
    for name, param in model.named_parameters():
        module = name[0: name.index('.')]
        if module not in modules:
            modules[module] = []
        modules[module].append((name, param))
    return modules
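# For the model above this returns something like
# {"encoder": [("encoder.0.weight", ...), ("encoder.0.bias", ...),
#              ("encoder.2.weight", ...), ("encoder.2.bias", ...),
#              ("encoder.4.weight", ...), ("encoder.4.bias", ...)]}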
# Hook class
class ActiveGradsHook:
    """
    Resets the gradient according to the passed masks.
    """
    def __init__(self, previously_active: List[bool], currently_active: List[bool], bias=False):
        # Could be None for biases
        if previously_active is not None:
            self.previously_active = torch.BoolTensor(previously_active).long().nonzero().view(-1).numpy()
        # Should never be None
        self.currently_active = torch.BoolTensor(currently_active).long().nonzero().view(-1).numpy()
        self.is_bias = bias
        # self.__name__ = None # THIS WILL FIX YOUR CRASH
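        # (Presumably anomaly mode looks up __name__ on each registered hook
        # when reporting; a plain function has one, an instance of this class
        # does not, hence the crash and the one-line workaround above.)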
    def __call__(self, grad):
        grad_clone = grad.clone().detach()
        if self.is_bias:
            grad_clone[self.currently_active] = 0
        else:
            grad_clone[self.currently_active, :] = 0
            grad_clone[:, self.previously_active] = 0
        return grad_clone
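# Note on the masks: in train_new_neurons below, currently_active is
# [False] * old_size + [True] * neurons_added, so __call__ zeroes the
# gradient entries at the newly added indices (and, for weight matrices,
# the columns flagged in previously_active as well).
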
def train_new_neurons(old_sizes, new_sizes, model, batch_loader, optimizer):
    if old_sizes == new_sizes:
        print("No new neurons to train.")
        return (None, None)
    print("Training new neurons...")
    # Generate hooks for each layer
    hooks = []
    modules = get_modules(model)
    for module_name, parameters in modules.items():
        previously_active_weights = [False] * new_sizes[module_name][0]
        for param_name, param in parameters:
            split_param_name = param_name.split(".")  # Splits e.g. "encoder.0.weight"
            # Map every two Sequential indices to one layer index (0, 2, 4 -> 0, 1, 2)
            param_index = int(split_param_name[1]) // 2
            old_size = old_sizes[module_name][param_index + 1]
            new_size = new_sizes[module_name][param_index + 1]
            neurons_added = new_size - old_size
            # Input/Output must stay the same
            if param_index == len(old_sizes[module_name]) - 1:
                assert old_size == new_size
                continue
            # Freeze biases/weights
            if "bias" in param_name:
                active_biases = [False] * old_size + [True] * neurons_added
                hook = ActiveGradsHook(None, active_biases, bias=True)
                handle = param.register_hook(hook)
                hooks.append(handle)
            else:
                active_weights = [False] * old_size + [True] * neurons_added
                hook = ActiveGradsHook(previously_active_weights, active_weights, bias=False)
                handle = param.register_hook(hook)
                hooks.append(handle)
                previously_active_weights = active_weights
    # Train simply
    train(batch_loader, model, optimizer)
    # Remove the RemovableHandles once the extra training pass is done
    for handle in hooks:
        handle.remove()
def train(batch_loader, model, optimizer):
    criterion = torch.nn.BCELoss()
    for batch_idx, (inputs, action_target) in enumerate(batch_loader):
        # detect_anomaly() is presumably what turns the hook's missing
        # __name__ into a hard crash (see ActiveGradsHook above)
        with torch.autograd.detect_anomaly():
            action_output = model(inputs)
            action_loss = criterion(action_output, action_target)
            # Compute gradient and do SGD step
            optimizer.zero_grad()
            action_loss.backward()
            optimizer.step()
def get_mnist_loader():
    dataset = datasets.MNIST

    def one_hot_mnist(targets):
        targets_onehot = torch.zeros(10)
        targets_onehot[targets] = 1
        return targets_onehot

    transform_all = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(lambda a: a.view(-1))
    ])
    root = os.path.join(os.path.dirname(__file__), "../../data/MNIST")
    dataset = dataset(root=root, train=True, download=True, transform=transform_all, target_transform=one_hot_mnist)
    sampler = RandomSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=128, num_workers=0, pin_memory=False)
if __name__ == "__main__":
    sizes = {"encoder": [784, 312, 128, 10]}
    new_sizes = {"encoder": [784, 624, 256, 10]}
    model = SimpleModel(sizes, None, None)
    optimizer = torch.optim.SGD(model.parameters(), lr=1, momentum=0)
    mnist_loader = get_mnist_loader()
    train(mnist_loader, model, optimizer)
    train_new_neurons(sizes, new_sizes, model, mnist_loader, optimizer)
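
For reference, here is a minimal sketch of what I believe the failing combination boils down to, with the model stripped away (the hook and tensor names here are made up for illustration, not taken from the script above): a callable class instance registered through Tensor.register_hook, with backward run under detect_anomaly().

import torch

class NamelessHook:
    # Same shape as ActiveGradsHook above: a callable instance,
    # which (unlike a plain function) has no __name__ attribute
    def __call__(self, grad):
        return grad.clone()

x = torch.ones(3, requires_grad=True)
x.register_hook(NamelessHook())
with torch.autograd.detect_anomaly():
    loss = (x * 2).sum()
    loss.backward()  # expected to die complaining about __name__;
                     # setting self.__name__ in __init__ works around it, as noted above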