Replace default F.Conv2d weights with pretrained ImageNet weights

I would like to initialize the layers of a ResNet18 with pretrained ImageNet weights. However, my ResNet18's layers have been rewritten like this:

import torch
import torch.nn.functional as F


class FreezableConv2d(torch.nn.Conv2d):
    """
        Conv2d layer with selectively frozen outputs.

        Ones in self.mask indicate which outputs are frozen.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):

        super().__init__(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=padding, dilation=dilation,
                         groups=groups, bias=bias, padding_mode=padding_mode)
        # about register_buffer(): https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723
        self.register_buffer('mask', torch.zeros(out_channels, 1, 1, 1))
        self.register_buffer('shadow_weight', torch.Tensor(*self.weight.shape)) # what is shadow_weight? 
        self.register_buffer('shadow_bias', torch.Tensor(*self.bias.shape)) # what is shadow_bias? 

        self.reset_shadow_parameters()

    def reset_shadow_parameters(self):
        self.shadow_weight.data.copy_(self.weight)
        self.shadow_bias.data.copy_(self.bias)

    def copy_weight2shadow_by_idx(self, idx):
        self.shadow_weight.data[idx] = self.weight.data[idx].clone()
        self.shadow_bias.data[idx] = self.bias.data[idx].clone()

    @property
    def frozen_weight(self):
        return self.shadow_weight * self.mask + self.weight * (1. - self.mask)

    @property
    def frozen_bias(self):
        return self.shadow_bias * self.mask.squeeze() + self.bias * (1. - self.mask.squeeze())

    def freeze(self, mask):
        r"""
        Updates the is_frozen mask.
        Assumes that each new mask is based on the previous one with some
        additional elements to be frozen.
        Args:
            mask: torch.Tensor, ones in the mask indicate the kernels to be frozen

        Returns:
            None
        """
        assert mask.ndim == 1
        new_mask = mask.detach().to(self.weight)
        diff = (self.mask[:, 0, 0, 0] - new_mask)
        # -1 in diff means that there is a new kernel to be frozen
        new_idx_to_freeze = (diff == -1).nonzero(as_tuple=False)[:, 0].long()
        self.copy_weight2shadow_by_idx(new_idx_to_freeze)

        mask = new_mask[:, None, None, None]
        self.register_buffer('mask', mask.contiguous())

    def reinit_unfrozen(self):
        r"""
        Reinitialize the self.weight and self.bias.

        shadow_weight and shadow_bias already hold the frozen weights,
        which are substituted in the proper places
        when frozen_weight/frozen_bias is called, so the whole self.weight and self.bias
        can be reinitialized.
        Returns:
            None
        """
        self.reset_parameters()

    def forward(self, x):
        return F.conv2d(x, self.frozen_weight, bias=self.frozen_bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

Then used in another class like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

# cfg, create_freezable_bn, BufferList and gumbel_sigmoid are project-specific
# helpers defined elsewhere in the codebase.

class BaseGatedConv(nn.Module):
    """
        Base class for single Task-gated conv layer.
        Incorporates all gating logic.
    """
    def __init__(self, in_ch, out_ch, freezing_method,
                 aggregate_firing=False,
                 N_tasks=1,
                 conv_params: dict = None):
        super().__init__()

        self.N_tasks = N_tasks
        self.in_ch = in_ch
        self.out_ch = out_ch

        if conv_params:
            self.conv2d = FreezableConv2d(in_ch, out_ch, **conv_params)
        else:
            self.conv2d = FreezableConv2d(in_ch, out_ch, kernel_size=3, padding=1)

        self.main_conv_path = nn.Sequential(self.conv2d)

        self.fbns = nn.ModuleList([create_freezable_bn(self.out_ch) for _ in range(self.N_tasks)])

        self.gates = nn.ModuleList([self.create_gate_fc() for _ in range(self.N_tasks)])

        self.taskwise_sparse_objective = BufferList([torch.empty((1))] * self.N_tasks) # used for sparsity objective

        # aggregates frequencies with which kernels were chosen
        self.aggregate_firing = aggregate_firing
        self.channels_firing_freq = torch.zeros((self.N_tasks, self.out_ch))
        self.n_aggregations = torch.zeros((self.N_tasks))  # for calculating probabilities correctly

        self.freezing_method = freezing_method
        
        self.register_buffer('frozen_kernels_mask',
                             torch.zeros((self.out_ch), dtype=int))

        # Maybe here I can change the weights? FreezableConv2d(in_ch, out_ch, **conv_params) is a child of torch.nn.Conv2d,
        # so I can probably access the weights and biases of FreezableConv2d as I can access any Conv2d in PyTorch.

    def create_gate_fc(self):
        gate_fc = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(self.in_ch, 16),
            nn.BatchNorm1d(16, track_running_stats=False) if cfg.USE_BATCHNORM_GATES else nn.Identity(),
            nn.ReLU(),
            nn.Linear(16, self.out_ch)
        )
        return gate_fc

    def add_task_path(self):
        r"""Add task path to the block"""
        r"""Following .to(some_weight) operation is to ensure, that all the gates and fbns are on the same device
           This method, however, will not work with fbns, if affine=False, so be aware of that."""
        self.gates.append(self.create_gate_fc().to(cfg.DEVICE))
        self.fbns.append(create_freezable_bn(self.out_ch).to(cfg.DEVICE))

        self.taskwise_sparse_objective.append(torch.empty((1)))

        self.channels_firing_freq = torch.cat([self.channels_firing_freq, 
                                               torch.zeros((1, self.out_ch))], 0)
        self.n_aggregations = torch.cat([self.n_aggregations,
                                        torch.zeros((1))], 0)

        self.N_tasks += 1

    def enable_gates_firing_tracking(self):
        self.aggregate_firing = True

    def reset_gates_firing_tracking(self):
        self.aggregate_firing = False
        self.channels_firing_freq = torch.zeros((self.N_tasks, self.out_ch))
        self.n_aggregations = torch.zeros((self.N_tasks))

    def aggregate_channels_firing_stat(self, channels_mask, task_idx):
        """
            Sums up frequencies of choosing kernels among batches
            during validation or test.

            Attributes:
            channels_mask - binary mask
        """
        self.channels_firing_freq[task_idx] += channels_mask.float().mean(dim=0).detach().cpu()
        self.n_aggregations[task_idx] += 1

    def update_relevant_kernels(self, task_id):
        """
            Updates relevant kernels according to each gate-path e.g. task
        """
        if self.freezing_method.freeze_fixed_proc:
            k = int(self.out_ch * self.freezing_method.freeze_top_proc)
            aggregated_times = self.n_aggregations[task_id]
            threshold = self.freezing_method.freeze_prob_thr * aggregated_times

            gate_stat = self.channels_firing_freq[task_id].clone()
            n_relevant = (gate_stat > threshold).long().sum()
            # gate_stat[gate_stat < threshold] = 0

            if n_relevant > k:
                print(f'Not enough capacity for relevant kernels: {n_relevant}/{k} ')
                idx_to_freeze = torch.topk(gate_stat, k, dim=-1)[1]
            else:
                idx_to_freeze = torch.topk(gate_stat, n_relevant, dim=-1)[1]

        else:
            gate_stat = self.channels_firing_freq[task_id]
            aggregated_times = self.n_aggregations[task_id]
            idx_to_freeze = gate_stat > self.freezing_method.freeze_prob_thr * aggregated_times

        # aggregated mask becomes non-binary, but this does not interfere
        # with the logic of self.freeze_relevant_kernels()
        # and underlines the relevances of the kernels once more
        self.frozen_kernels_mask[idx_to_freeze] += 1

    def freeze_relevant_kernels(self, task_id):
#         from pdb import set_trace; set_trace()
        self.update_relevant_kernels(task_id)
        self.conv2d.freeze(self.frozen_kernels_mask.clamp(0, 1))

        r"""During training of task t only self.fbns[t] tracked relevant statistics, 
            therefore it is the only element to freeze"""
        if cfg.NORMALIZATION_IN_BACKBONE:
            self.fbns[task_id].freeze(self.frozen_kernels_mask.clamp(0, 1))


    def reinitialize_irrelevant_kernels(self):
        r"""
        Invoke all freezable classes to reinitialize unfrozen kernels
        """
        self.conv2d.reinit_unfrozen()

        r"""Despite only self.fbns[t] was properly frozen, all fbns should reinit irreevant parameters, 
            according to their is_frozen masks"""
        if cfg.NORMALIZATION_IN_BACKBONE:
            for fbn in self.fbns:
                fbn.reinit_unfrozen()

    def sample_channels_mask(self, logits):
        """
            Samples binary mask to select
            relevant output channel of the convolution

            Attributes:
            logits - log-probabilities of the Bernoulli variables
                for each output channel of the convolution to be selected
        """
        if self.training:
            if cfg.USE_GUMBEL_SIGMOID:
                channels_mask = gumbel_sigmoid(logits, tau=2/3)
            else:
                bernoully_logits = torch.stack([logits, -logits], dim=0)
                channels_mask = F.gumbel_softmax(bernoully_logits, tau=2/3, hard=True, dim=0)[0]
        else:
            channels_mask = (logits > 0).long()
        return channels_mask

    def select_channels_for_task(self, x, filters, gate_fc, task_idx):
        """
            Performs selection of the output channels for the given task.

            Attributes:
            x - input tensor
            filters - output tensor to be selected from
            gate_fc - sequential model, provides logprobabilities for each output channel of the convolution
            task_idx - int label of the task path; used for gate firing aggregation
        """
#         from pdb import set_trace; set_trace()
        logits = gate_fc(x)
        mask = self.sample_channels_mask(logits)
        self.taskwise_sparse_objective[task_idx] = mask.float().mean().reshape(1)

        if self.aggregate_firing:
            self.aggregate_channels_firing_stat(mask, task_idx)

        # expand last 2 dims for channel-level elementwise multiplication
        mask = mask[:, :, None, None]
        return filters * mask

    def forward(self, x):

        # task-agnostic conv2D over 5D tensor
        bs, N_tasks, N_ch, H, W = x.shape
        filters = self.main_conv_path(x.reshape(bs * N_tasks, N_ch, H, W))
        N_ch, H, W = filters.shape[-3:]
        filters = filters.reshape(bs, N_tasks, N_ch, H, W)

        after_bn = []
        for task_input, task_fbn in zip(filters.transpose(0, 1), self.fbns):
                after_bn.append(F.relu(task_fbn(task_input)))

        # permute batch_size and N_tasks dims for task-wise iterations
        x = x.transpose(0, 1)
        after_bn = torch.stack(after_bn, dim=0)

        output = []
        for task_idx, (task_input, task_filters, task_gate) in enumerate(zip(x, after_bn, self.gates)):
            selected = self.select_channels_for_task(
                task_input, task_filters, task_gate, task_idx)
            output.append(selected)

        # TODO: change torch.stack(smth).transpose(0, 1) to torch.stack(smth, 1)
        # permute back to batch_size x N_tasks x N_channels x H x W
        output = torch.stack(output).transpose(0, 1)
        return output

In the FreezableConv2d class above, I would like to replace each layer's initial weights and biases with the corresponding weights of the matching layer of a ResNet18 pretrained on ImageNet.
Is there a way to do this simply and efficiently?

One way would be to get the state_dict of the pretrained model and remap its keys to the parameter names of your custom model.
However, since you've not only changed the attribute names but also added new buffers (mask, shadow_weight, shadow_bias), directly loading this manipulated state_dict will fail, and you would have to use strict=False, which is also risky since it can silently skip real mismatches.
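A lower-risk variant of that approach is to build the remapped dict yourself, keep only the entries that exist in the custom model with matching shapes, and then inspect what the strict=False load actually skipped. Here is a minimal sketch; the remap_key function is an assumption that your custom ResNet18 keeps the torchvision layer names and only wraps each conv in a block whose FreezableConv2d child is named conv2d (as in BaseGatedConv above), so adjust it to your actual module names:

import torch
import torchvision


def load_pretrained_convs(custom_model: torch.nn.Module) -> None:
    """Copy pretrained ResNet18 conv weights into a model built from FreezableConv2d layers."""
    # weights= needs torchvision >= 0.13; use pretrained=True on older versions
    pretrained_sd = torchvision.models.resnet18(weights="IMAGENET1K_V1").state_dict()
    custom_sd = custom_model.state_dict()

    def remap_key(key: str) -> str:
        # Assumed mapping: "layer1.0.conv1.weight" -> "layer1.0.conv1.conv2d.weight",
        # i.e. every conv lives one level deeper inside its gated block.
        if key.endswith(("conv1.weight", "conv2.weight", "downsample.0.weight")):
            prefix, param = key.rsplit(".", 1)
            return f"{prefix}.conv2d.{param}"
        return key

    # Keep only entries whose remapped key exists in the custom model with a
    # matching shape, so strict=False cannot silently swallow a real mismatch.
    filtered = {}
    for key, value in pretrained_sd.items():
        new_key = remap_key(key)
        if new_key in custom_sd and custom_sd[new_key].shape == value.shape:
            filtered[new_key] = value

    missing, unexpected = custom_model.load_state_dict(filtered, strict=False)
    print(f"copied {len(filtered)} tensors; {len(missing)} parameters/buffers left at their default init")

    # FreezableConv2d (the class from the snippet above) keeps shadow copies of
    # weight/bias, so refresh them to match the freshly loaded values.
    for module in custom_model.modules():
        if isinstance(module, FreezableConv2d):
            module.reset_shadow_parameters()

Printing (or asserting on) missing and unexpected makes the strict=False load auditable, and any conv whose shape differs in your custom model simply stays at its default initialization instead of being silently mis-loaded.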
