Hello,
I am trying to implement the meta-learning net from [1712.09926] Rapid Adaptation with Conditionally Shifted Neurons.
In the learning phase I am inspecting the model.parameters() and model.state_dict() and they both stay the same after the optimizer step.
The training loop looks like this:
# Meta-training loop: each iteration samples one few-shot task, lets the model
# memorise the support set (description phase) and then takes one optimizer
# step per query sample (prediction phase).
model = CSNMetaLearner(n_way, device=device).to(device)
# model.apply(weights_init)
clip_value = 10
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# NOTE(review): nn.CrossEntropyLoss applies log-softmax itself and expects raw
# logits, but the base learner in this file ends in nn.Softmax. If that softmax
# saturates, the gradients reaching the parameters can be exactly zero, in
# which case Adam performs no update -- verify by inspecting p.grad after
# l.backward().
loss = nn.CrossEntropyLoss().to(device)
epochs = 100
for i in trange(epochs):
    # NOTE(review): the same random_seed is passed on every iteration; if
    # get_task is deterministic in the seed, every epoch sees the same task.
    support_set_indices, task_indices, labels = emblematica.get_task(n_way, k_shot, m_test, random_seed=1)
    # Map the task's original labels onto the contiguous range 0..n_way-1.
    few_shot_labels = dict(zip(labels, range(n_way)))
    support_set = []
    for j in support_set_indices:
        # NOTE(review): the sample is modified in place; if the dataset hands
        # out shared objects the remapped label leaks into later lookups.
        support_sample = emblematica[j]
        support_sample['label_index'] = few_shot_labels[support_sample['label_index']]
        # Without num_classes, one_hot infers the width from the value itself
        # (label + 1), so vectors of different labels would differ in length.
        support_sample['label_vector'] = torch.nn.functional.one_hot(
            torch.tensor(support_sample['label_index']), num_classes=n_way)
        support_set.append(support_sample)
    task = [emblematica[k] for k in task_indices]
    # description phase
    model.description_phase(support_set)
    # prediction phase
    for sample in tqdm(task, leave=False):
        # Add batch and channel dimensions to the single image.
        img = sample['image'][None, None, ...].float().to(device)
        y_pred = model(img)
        few_shot_label = few_shot_labels[sample['label_index']]
        l = loss(y_pred, torch.tensor([few_shot_label], device=device))
        # Debug dump of the parameters before the update. str() of a tensor is
        # an abbreviated, rounded rendering, so two identical dumps do not by
        # themselves prove that the parameters are bit-identical.
        with open('temp_state_pre', 'w') as file_:
            file_.write(str(list(model.parameters())))
        optimizer.zero_grad()
        l.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        # Debug dump of the parameters after the update.
        with open('temp_state_post', 'w') as file_:
            file_.write(str(list(model.parameters())))
The main class looks like this:
import torch.nn as nn
import torch
import math
from adacnn import SmallerAdaptiveCNN as AdaptiveCNN
from adacnn import MemoryFunction
from tqdm import tqdm, trange
def weights_init(m):
    """Re-initialise the weight of a conv or linear layer with Kaiming-uniform.

    Meant to be passed to ``nn.Module.apply``. Only ``nn.Conv2d`` and
    ``nn.Linear`` modules are touched, and only their ``weight`` (the ReLU
    gain is used); every other module, and all biases, are left as they are.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
class CSNMetaLearner(nn.Module):
    """Few-shot learner that conditions a base CNN on per-sample activation shifts.

    Usage is two-phase. ``description_phase`` stores, for every support
    sample, a key (the output of ``f``) and a list of per-layer shift tensors
    (produced by the memory function ``g``). ``forward`` then embeds the query
    with ``f``, selects the shifts of the most similar key (hard attention)
    and runs the base learner with those shifts.

    NOTE(review): only ``base_learner`` can receive gradients from a loss on
    ``forward``'s output. The shifts are computed under ``torch.no_grad()``
    and the key lookup is a non-differentiable arg-max, so the parameters of
    ``g`` and ``f`` are not updated by an optimizer as the class is written.

    Args:
        number_of_classes: Output size of the base learner and input size of
            the memory function.
        image_size: Height/width of the (square) input images; used to derive
            the feature-map size handed to the CNNs.
        memory_hidden_units: Hidden width of the memory function.
        device: Device for the sub-modules; CUDA if available when ``None``.
        key_size: Output size of the key-embedding network ``f``.
        image_channels: Number of input image channels.
    """

    def __init__(self, number_of_classes, image_size=640, memory_hidden_units=20, device=None, key_size=64, image_channels=1):
        super(CSNMetaLearner, self).__init__()
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        # Memory written by description_phase(); empty until it has run.
        self.support_set = []
        self.keys = []
        self.shifts = []

        # Each CNN block is a 3x3 convolution followed by 2x2 max-pooling,
        # which maps a side length s to floor(s / 2) - 1. The size used to be
        # hard-coded to 640 regardless of the image_size argument.
        feature_size = image_size
        cnn_blocks = 4
        for _ in range(cnn_blocks):
            feature_size = math.floor(feature_size / 2) - 1

        self.base_learner = AdaptiveCNN(image_channels, number_of_classes, feature_size).to(self.device)
        self.g = MemoryFunction(memory_hidden_units, number_of_classes).to(self.device)
        self.f = AdaptiveCNN(image_channels, key_size, feature_size).to(self.device)

    def description_phase(self, support_set):
        """Fill the key/shift memory from ``support_set``.

        Every sample must provide ``'image'`` (a 2-D tensor; batch and channel
        dimensions are added here) and ``'label_index'``. Any previously
        stored keys and shifts are discarded.
        """
        self.support_set = support_set
        self.keys = []
        self.shifts = []
        for sample in tqdm(self.support_set, leave=False):
            img = sample['image'][None, None, ...].float().to(self.device)
            # Unshifted prediction plus the hidden activations it recorded.
            result, intermediate = self.base_learner(img)
            # Embed the sample with f; the embedding is the memory key.
            key, _ = self.f(img)
            # Turn the direct-feedback information into per-layer shifts.
            layer_shifts = self.calc_direct_feedback_information(result, intermediate, sample['label_index'])
            self.keys.append(key)
            self.shifts.append(layer_shifts)

    def calc_direct_feedback_information(self, y_pred, intermediate, y):
        """Compute one shift tensor per recorded hidden activation.

        For every activation element the input to ``g`` is the derivative of
        the cross-entropy loss w.r.t. ``y_pred`` multiplied by the ReLU
        derivative at that element. That derivative is 1 for positive values
        and 0 otherwise, so ``g`` has only two possible inputs; it is
        evaluated once for each and the results are scattered with
        ``torch.where`` instead of calling ``g`` per element.

        Args:
            y_pred: Base-learner output for a single sample (batch size 1);
                must be part of an autograd graph.
            intermediate: Hidden activations, each with a leading batch
                dimension of size 1.
            y: Integer class label of the sample.

        Returns:
            A list with one tensor per entry of ``intermediate``, shaped like
            that entry without its batch dimension. The tensors carry no
            autograd history.
        """
        loss = nn.CrossEntropyLoss()
        loss_value = loss(y_pred, torch.tensor([y], device=self.device))
        loss_derivative = torch.autograd.grad(loss_value, y_pred)[0]

        with torch.no_grad():
            # g's answer where the ReLU derivative is 1 / where it is 0.
            shift_active = self.g(loss_derivative).reshape(())
            shift_inactive = self.g(loss_derivative * 0).reshape(())

            layer_shift_templates = []
            for intermediate_tensor in intermediate:
                # Drop only the batch dimension so that single-channel or 1x1
                # feature maps keep their remaining dimensions.
                activation = intermediate_tensor.squeeze(0)
                layer_shift_templates.append(
                    torch.where(activation > 0, shift_active, shift_inactive))
        return layer_shift_templates

    # prediction phase
    def forward(self, x):
        """Predict ``x`` with the base learner, shifted by the best-matching memory entry."""
        # Look up the shifts for this sample ...
        shifts = self.get_shift_for_sample(x)
        # ... and hand them to the base learner.
        result, _ = self.base_learner(x, shifts)
        return result

    # currently gets the closest one (hard attention)
    def get_shift_for_sample(self, x):
        """Return the stored shifts whose key is most cosine-similar to ``f(x)``.

        The first key wins on ties. Unlike the previous threshold-at-zero
        search, this also works when every similarity is zero or negative.

        Raises:
            RuntimeError: If no keys are stored, i.e. ``description_phase``
                has not been run with a non-empty support set.
        """
        if not self.keys:
            raise RuntimeError('description_phase() must be called with a non-empty support set before prediction')
        embedded_sample, _ = self.f(x)
        cosine_similarity = nn.CosineSimilarity()
        similarities = [float(cosine_similarity(embedded_sample, key)) for key in self.keys]
        # max() returns the first maximal index, matching the original
        # strict '>' comparison for ties.
        best_index = max(range(len(similarities)), key=similarities.__getitem__)
        return self.shifts[best_index]
It uses the following module (`adacnn`) for the base learner; the memory function is defined at the end of it:
import torch
import torch.nn as nn
class AdaptiveCNNBlock(nn.Module):
    """One CNN stage: 3x3 convolution, ReLU, then 2x2 max-pooling.

    The convolution is unpadded, so a square input of side ``s`` leaves the
    block with side ``floor((s - 2) / 2)``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Apply convolution, ReLU and pooling to ``x``."""
        return self.blocks(x)
class AdaptiveCNN(nn.Module):
    """Five-block CNN whose hidden activations can be recorded or shifted.

    Args:
        in_channels: Channels of the input image.
        out_features: Size of the output vector.
        image_size: Side length of the feature map entering the linear layer
            (its input width is ``intermediate_channels * image_size**2``),
            not the side length of the input image.
        intermediate_channels: Channel count of every CNN block.
        apply_softmax: Keep the final ``nn.Softmax`` (default, the historical
            behaviour). ``nn.CrossEntropyLoss`` expects unnormalised logits,
            so pass ``False`` when the output is trained with that loss;
            softmax-ed outputs can saturate and yield vanishing gradients.
    """

    def __init__(self, in_channels, out_features, image_size, intermediate_channels=64, *, apply_softmax=True):
        super(AdaptiveCNN, self).__init__()
        layers = [AdaptiveCNNBlock(in_channels, intermediate_channels)]
        layers += [AdaptiveCNNBlock(intermediate_channels, intermediate_channels) for _ in range(4)]
        layers += [nn.Flatten(),
                   nn.Linear(intermediate_channels * image_size**2, out_features)]
        if apply_softmax:
            # Softmax has no parameters, so omitting it leaves state_dict keys unchanged.
            layers.append(nn.Softmax(dim=-1))
        self.blocks = nn.Sequential(*layers)

    def forward(self, x, shifts=None):
        """Run the network, recording or shifting the hidden activations.

        Args:
            x: Input batch.
            shifts: Optional sequence with one tensor per CNN block after the
                first. When given, ``relu(shift)`` is added to the (re-ReLU-ed)
                output of the corresponding block.

        Returns:
            ``(result, intermediate)``. ``intermediate`` lists the outputs of
            every CNN block after the first when ``shifts`` is ``None`` and is
            empty otherwise.
        """
        intermediate = []
        result = x
        shifts_index = 0
        for i, layer in enumerate(self.blocks):
            result = layer(result)
            # Only the CNN blocks after the first one are recorded/shifted
            # (previously expressed as the magic index range 0 < i < 5).
            if i == 0 or not isinstance(layer, AdaptiveCNNBlock):
                continue
            if shifts is None:
                intermediate.append(result)
            else:
                result = nn.functional.relu(result) + nn.functional.relu(shifts[shifts_index])
                shifts_index += 1
        return result, intermediate
class SmallerAdaptiveCNN(nn.Module):
    """Four-block CNN whose hidden activations can be recorded or shifted.

    Args:
        in_channels: Channels of the input image.
        out_features: Size of the output vector.
        image_size: Side length of the feature map entering the linear layer
            (its input width is ``intermediate_channels * image_size**2``),
            not the side length of the input image.
        intermediate_channels: Channel count of every CNN block.
        apply_softmax: Keep the final ``nn.Softmax`` (default, the historical
            behaviour). ``nn.CrossEntropyLoss`` expects unnormalised logits,
            so pass ``False`` when the output is trained with that loss;
            softmax-ed outputs can saturate and yield vanishing gradients.
    """

    def __init__(self, in_channels, out_features, image_size, intermediate_channels=64, *, apply_softmax=True):
        super(SmallerAdaptiveCNN, self).__init__()
        layers = [AdaptiveCNNBlock(in_channels, intermediate_channels)]
        layers += [AdaptiveCNNBlock(intermediate_channels, intermediate_channels) for _ in range(3)]
        layers += [nn.Flatten(),
                   nn.Linear(intermediate_channels * image_size**2, out_features)]
        if apply_softmax:
            # Softmax has no parameters, so omitting it leaves state_dict keys unchanged.
            layers.append(nn.Softmax(dim=-1))
        self.blocks = nn.Sequential(*layers)

    def forward(self, x, shifts=None):
        """Run the network, recording or shifting the hidden activations.

        Args:
            x: Input batch.
            shifts: Optional sequence with one tensor per CNN block after the
                first. When given, ``relu(shift)`` is added to the (re-ReLU-ed)
                output of the corresponding block.

        Returns:
            ``(result, intermediate)``. ``intermediate`` lists the outputs of
            every CNN block after the first when ``shifts`` is ``None`` and is
            empty otherwise.
        """
        intermediate = []
        result = x
        shifts_index = 0
        for i, layer in enumerate(self.blocks):
            result = layer(result)
            # Only the CNN blocks after the first one are recorded/shifted
            # (previously expressed as the magic index range 0 < i < 4).
            if i == 0 or not isinstance(layer, AdaptiveCNNBlock):
                continue
            if shifts is None:
                intermediate.append(result)
            else:
                result = torch.add(nn.functional.relu(result), nn.functional.relu(shifts[shifts_index]))
                shifts_index += 1
        return result, intermediate
class MemoryFunction(nn.Module):
    """Stack of three linear layers mapping an information vector to one value.

    Layer widths are ``information_size -> units_per_layer -> units_per_layer
    -> 1``. There is no activation between the layers, so the whole module is
    an affine map of its input.
    """

    def __init__(self, units_per_layer, information_size):
        super().__init__()
        layer_sizes = [
            (information_size, units_per_layer),
            (units_per_layer, units_per_layer),
            (units_per_layer, 1),
        ]
        self.blocks = nn.Sequential(*(nn.Linear(n_in, n_out) for n_in, n_out in layer_sizes))

    def forward(self, x):
        """Return the scalar output for every row of ``x``."""
        return self.blocks(x)
I don't get any error message, but the two files written in the prediction phase (before and after the optimizer step) have identical contents.
The loss is greater than 0.
If anyone could point out the error or has any idea what the problem could be, I would be very thankful.
Thanks in advance!