Parameters don't change during training

Hello,

I am trying to implement the meta-learning network from [1712.09926] Rapid Adaptation with Conditionally Shifted Neurons.
During training I am inspecting model.parameters() and model.state_dict(), and both stay exactly the same after the optimizer step.
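
This is roughly the check I am running (a minimal sketch; model is the CSNMetaLearner instance created below):

    # snapshot all parameters before backward/step ...
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    # ... l.backward() and optimizer.step() run here ...
    # ... then compare each parameter against its snapshot afterwards
    for name, p in model.named_parameters():
        if torch.equal(p.detach(), before[name]):
            print(name, 'unchanged')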

The training loop looks like this:


    model = CSNMetaLearner(n_way, device=device).to(device)
    # model.apply(weights_init)

    clip_value = 10
    learning_rate = 0.001
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss().to(device)

    epochs = 100

    for i in trange(epochs):
        support_set_indices, task_indices, labels = emblematica.get_task(n_way, k_shot, m_test, random_seed=1)
        few_shot_labels = dict(zip(labels, range(n_way)))

        support_set = []
        for j in support_set_indices:
            support_sample = emblematica[j]
            support_sample['label_index'] = few_shot_labels[support_sample['label_index']]
            # pass num_classes explicitly so the one-hot width is the same for every sample
            support_sample['label_vector'] = torch.nn.functional.one_hot(torch.tensor(support_sample['label_index']), num_classes=n_way)
            support_set.append(support_sample)

        task = []
        for k in task_indices:
            task.append(emblematica[k])

        # description phase
        model.description_phase(support_set)

        # prediction phase
        for sample in tqdm(task, leave=False):
            img = sample['image'][None, None, ...].float().to(device)
            y_pred = model(img)

            few_shot_label = few_shot_labels[sample['label_index']]
            l = criterion(y_pred, torch.tensor([few_shot_label], device=device))
            # dump the parameters before the optimizer step
            with open('temp_state_pre', 'w') as file_:
                file_.write(str(list(model.parameters())))
            optimizer.zero_grad()

            l.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            optimizer.step()
            # dump the parameters after the optimizer step
            with open('temp_state_post', 'w') as file_:
                file_.write(str(list(model.parameters())))

The main class looks like this:

import torch.nn as nn
import torch
import math
from adacnn import SmallerAdaptiveCNN as AdaptiveCNN
from adacnn import MemoryFunction
from tqdm import tqdm, trange



def weights_init(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
    elif isinstance(m, nn.Linear):
        torch.nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')

class CSNMetaLearner(nn.Module):
    def __init__(self, number_of_classes, image_size=640, memory_hidden_units=20, device=None, key_size=64, image_channels=1):
        super(CSNMetaLearner, self).__init__()
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        feature_size = image_size  # image height / width
        cnn_blocks = 4
        for _ in range(cnn_blocks):
            # each block: 3x3 conv (no padding) followed by 2x2 max-pooling
            feature_size = math.floor(feature_size / 2) - 1
        self.base_learner = AdaptiveCNN(image_channels, number_of_classes, feature_size).to(self.device)
        self.g = MemoryFunction(memory_hidden_units, number_of_classes).to(self.device)
        self.f = AdaptiveCNN(image_channels, key_size, feature_size).to(self.device)


    def description_phase(self, support_set):
        self.support_set = support_set

        self.keys = []
        self.shifts = []

        for sample in tqdm(self.support_set, leave=False):
            # predict with the base learner
            img = sample['image'][None, None, ...].float().to(self.device)
            result, intermediate = self.base_learner(img)
            # embed the sample with f
            key, _ = self.f(img)
            # compute the direct feedback (DF) information
            layer_shifts = self.calc_direct_feedback_information(result, intermediate, sample['label_index'])

            self.keys.append(key)
            self.shifts.append(layer_shifts)


    def calc_direct_feedback_information(self, y_pred, intermediate, y):
        loss = nn.CrossEntropyLoss()
        loss_value = loss(y_pred, torch.tensor([y]).to(self.device))
        loss_derivative = torch.autograd.grad(loss_value, y_pred)[0]

        layer_shift_templates = []
        for intermediate_tensor in tqdm(intermediate, leave=False):
            a_t = torch.flatten(intermediate_tensor, start_dim=2)
            intermediate_shape = intermediate_tensor.shape
            intermediate_shifts = []
            for filter_ in tqdm(torch.squeeze(a_t), leave=False):
                a_tl = filter_.requires_grad_(True).to(self.device)
                a_sigma = nn.functional.relu(a_tl).to(self.device)
                a_tgrads = torch.autograd.grad(a_sigma, a_tl, torch.ones(a_tl.shape[0], device=self.device))[0].to(self.device)
                with torch.no_grad():
                    filter_shifts = []
                    for a_tgrad in a_tgrads:
                        i_tli = loss_derivative * a_tgrad
                        shift = self.g(i_tli)
                        filter_shifts.append(shift)
                    filter_shifts_vector = torch.cat(filter_shifts)
                    filter_shift = filter_shifts_vector.reshape(intermediate_shape[2:])
                    intermediate_shifts.append(filter_shift)
            intermediate_shift = torch.stack(intermediate_shifts)
            layer_shift_templates.append(intermediate_shift)
        return layer_shift_templates

    # prediction phase
    def forward(self, x):
        # determine the shifts for this sample
        shifts = self.get_shift_for_sample(x)
        # pass the shifts to the base learner
        result, _ = self.base_learner(x, shifts)
        return result

    # currently gets the closest one (hard attention)
    def get_shift_for_sample(self, x):
        embedded_sample, _ = self.f(x)
        cosine_similarity = nn.CosineSimilarity()
        sim_ = -1.0  # start below any possible cosine similarity
        highest_sim_index = 0  # so an index is always set, even if all similarities are negative
        for i, key in enumerate(self.keys):
            current_sim = cosine_similarity(embedded_sample, key)
            if current_sim > sim_:
                highest_sim_index = i
                sim_ = current_sim
        return self.shifts[highest_sim_index]

It uses the following for the base learner; the memory function can be found at the end:

import torch
import torch.nn as nn

class AdaptiveCNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AdaptiveCNNBlock, self).__init__()
        self.blocks = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3),
                                    nn.ReLU(),
                                    nn.MaxPool2d(2))

    def forward(self, x):
        result = self.blocks(x)
        return result

class AdaptiveCNN(nn.Module):
    def __init__(self, in_channels, out_features, image_size, intermediate_channels=64):
        super(AdaptiveCNN, self).__init__()
        self.blocks = nn.Sequential(AdaptiveCNNBlock(in_channels, intermediate_channels),
                                    AdaptiveCNNBlock(intermediate_channels, intermediate_channels),
                                    AdaptiveCNNBlock(intermediate_channels, intermediate_channels),
                                    AdaptiveCNNBlock(intermediate_channels, intermediate_channels),
                                    AdaptiveCNNBlock(intermediate_channels, intermediate_channels),
                                    nn.Flatten(),
                                    nn.Linear(intermediate_channels * image_size**2, out_features),
                                    nn.Softmax(dim=-1))

    def forward(self, x, shifts=None):
        intermediate = []
        result = x
        shifts_index = 0
        for i, layer in enumerate(self.blocks):
            result = layer(result)
            if 0 < i < 5:
                if shifts is not None:
                    result = nn.functional.relu(result) + nn.functional.relu(shifts[shifts_index])
                    shifts_index += 1
                else:
                    intermediate.append(result)
        return result, intermediate


class SmallerAdaptiveCNN(nn.Module):
    def __init__(self, in_channels, out_features, image_size, intermediate_channels=64):
        super(SmallerAdaptiveCNN, self).__init__()
        self.blocks = nn.Sequential(AdaptiveCNNBlock(in_channels, intermediate_channels),
                                    AdaptiveCNNBlock(intermediate_channels, intermediate_channels),
                                    AdaptiveCNNBlock(intermediate_channels, intermediate_channels),
                                    AdaptiveCNNBlock(intermediate_channels, intermediate_channels),
                                    nn.Flatten(),
                                    nn.Linear(intermediate_channels * image_size**2, out_features),
                                    nn.Softmax(dim=-1))

    def forward(self, x, shifts=None):
        intermediate = []
        result = x
        shifts_index = 0
        for i, layer in enumerate(self.blocks):
            result = layer(result)
            if 0 < i < 4:
                if shifts is not None:
                    result = torch.add(nn.functional.relu(result), nn.functional.relu(shifts[shifts_index]))
                    shifts_index += 1
                else:
                    intermediate.append(result)
        return result, intermediate

class MemoryFunction(nn.Module):
    def __init__(self, units_per_layer, information_size):
        super(MemoryFunction, self).__init__()
        self.blocks = nn.Sequential(nn.Linear(information_size, units_per_layer),
                                   nn.Linear(units_per_layer, units_per_layer),
                                   nn.Linear(units_per_layer, 1))

    def forward(self, x):
        result = self.blocks(x)
        return result

I don't get any error message, but the outputs written to the two files in the prediction phase (temp_state_pre and temp_state_post) are identical.
The loss is greater than 0.
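
To narrow it down, a check like this right after l.backward() should show whether the gradients are missing entirely or just very small (a minimal sketch, not yet part of my code):

    # inspect the gradients directly after the backward pass
    for name, p in model.named_parameters():
        if p.grad is None:
            print(name, 'has no gradient')
        else:
            print(name, 'grad norm:', p.grad.norm().item())
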
If anyone could point out the error, or has any idea what the problem could be, I would be very thankful.

Thanks in advance!