On the equivalence of two implementations

Hello dear forum,

I am trying to understand whether two implementations are equivalent. For instance, the reference implementation extracts patches with Tensor.unfold while I decided to use nn.Unfold in mine, and the softmax is applied differently in the two versions.

The reference implementation is from leaderj1001 on GitHub.

It implements the model from the paper Stand-Alone Self-Attention in Vision Models.
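
To make the unfold question concrete, this is the kind of sanity check I have in mind (a minimal sketch with arbitrary sizes, not taken from either implementation): extract patches once with Tensor.unfold as the reference does, and once with nn.Unfold plus a reshape as I do, and compare.

# sanity check: Tensor.unfold vs nn.Unfold patch extraction
import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, H, W = 2, 8, 5, 5
k, s, p = 3, 1, 1
x = torch.randn(N, C, H, W)

# reference style: pad, then Tensor.unfold over the height and width dims
a = F.pad(x, [p, p, p, p]).unfold(2, k, s).unfold(3, k, s)      # (N, C, H', W', k, k)

# my style: nn.Unfold gives (N, C*k*k, H'*W'); reshape back to patches
h_out = (H + 2 * p - k) // s + 1
w_out = (W + 2 * p - k) // s + 1
b = nn.Unfold(kernel_size=k, padding=p, stride=s)(x)
b = b.view(N, C, k, k, h_out, w_out).permute(0, 1, 4, 5, 2, 3)  # (N, C, H', W', k, k)

print(torch.allclose(a, b))  # should print True if the two patch extractions agree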

# original
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

import math


class AttentionConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False):
        super(AttentionConv, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups

        assert self.out_channels % self.groups == 0, "out_channels should be divided by groups. (example: out_channels: 40, groups: 4)"

        self.rel_h = nn.Parameter(torch.randn(out_channels // 2, 1, 1, kernel_size, 1), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn(out_channels // 2, 1, 1, 1, kernel_size), requires_grad=True)

        self.key_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.value_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

        self.reset_parameters()

    def forward(self, x):
        batch, channels, height, width = x.size()

        padded_x = F.pad(x, [self.padding, self.padding, self.padding, self.padding])
        q_out = self.query_conv(x)
        k_out = self.key_conv(padded_x)
        v_out = self.value_conv(padded_x)

        k_out = k_out.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)
        v_out = v_out.unfold(2, self.kernel_size, self.stride).unfold(3, self.kernel_size, self.stride)

        k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1)
        k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1)

        k_out = k_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1)
        v_out = v_out.contiguous().view(batch, self.groups, self.out_channels // self.groups, height, width, -1)

        q_out = q_out.view(batch, self.groups, self.out_channels // self.groups, height, width, 1)

        out = q_out * k_out
        out = F.softmax(out, dim=-1)
        out = torch.einsum('bnchwk,bnchwk -> bnchw', out, v_out).view(batch, -1, height, width)

        return out

    def reset_parameters(self):
        init.kaiming_normal_(self.key_conv.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.value_conv.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.query_conv.weight, mode='fan_out', nonlinearity='relu')

        init.normal_(self.rel_h, 0, 1)
        init.normal_(self.rel_w, 0, 1)

# my implementation
import torch
from torch import nn
from torch.nn import functional as F
from math import sqrt
from einops import rearrange
from einops.layers.torch import Rearrange
from opt_einsum import contract as einsum
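
# init_weights is used below but its definition was not included in this snippet;
# a minimal sketch is assumed here, mirroring the reference's kaiming_normal_
# initialisation of the 1x1 convolutions.
def init_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
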

class RelativeEmbeddings2d(nn.Module):
    def __init__(self, extent, embedding_size):
        super(RelativeEmbeddings2d, self).__init__()

        assert type(extent) == int, 'RelativeEmbeddings2d requires integer extent'

        self.extent = extent
        self.embedding_size = embedding_size
        self.width_mat = nn.Parameter(torch.randn((1, embedding_size // 2, 1, extent, 1)), requires_grad=True)
        self.height_mat = nn.Parameter(torch.randn((1, embedding_size // 2, extent, 1, 1)), requires_grad=True)

    def forward(self, x):
        x_h, x_w = rearrange(
            x, 'N (C K1 K2) L -> N C K1 K2 L',
            C=self.embedding_size, K1=self.extent
        ).split(self.embedding_size // 2, dim=1)
        return rearrange(
            torch.cat((x_h + self.height_mat, x_w + self.width_mat), dim=1),
            'N C K1 K2 L -> N (C K1 K2) L'
        )


class SASAConv2d(nn.Module):
    """Stand-alone Self-attention 2d"""

    # output spatial size: (W - F + 2P) / S + 1
    def __init__(self, in_channels, out_channels, kernel_size, heads=4, stride=1):
        super(SASAConv2d, self).__init__()

        assert heads > 0, 'SASAConv2d requires a positive number of heads'
        assert type(kernel_size) == int, 'SASAConv2d requires integer kernel_size'
        assert out_channels % heads == 0, 'SASAConv2d requires out_channels divisible by the number of heads'

        padding = (kernel_size - 1) // 2
        self.heads = heads
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.q_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Unfold(kernel_size=1, dilation=1, padding=0, stride=stride),
            Rearrange('N (M D) HW -> (N HW M) () D', M=self.heads)
        )
        self.q_conv.apply(init_weights)
        self.k_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Unfold(kernel_size, dilation=1, padding=padding, stride=stride),
            RelativeEmbeddings2d(extent=kernel_size, embedding_size=out_channels),
            Rearrange('N (M D KK) HW -> (N HW M) D KK', M=self.heads, KK=self.kernel_size ** 2)
        )
        self.k_conv.apply(init_weights)
        self.v_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Unfold(kernel_size, dilation=1, padding=padding, stride=stride),
            Rearrange('N (M D KK) HW -> (N HW M) KK D', M=self.heads, KK=self.kernel_size ** 2)
        )
        self.v_conv.apply(init_weights)

    def forward(self, x):
        N, C, H, W = x.size()

        q = self.q_conv(x)
        k = self.k_conv(x)
        v = self.v_conv(x)

        weights = F.softmax(q.bmm(k), dim=-1)
        attn_maps = weights.bmm(v)
        return rearrange(attn_maps, '(N H W M) () D -> N (M D) H W', N=N, H=H, W=W)
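
For context, this is how I instantiate and smoke-test the two modules side by side (the hyperparameters are arbitrary; here I only check that the output shapes agree):

# shape smoke test for both modules
x = torch.randn(2, 16, 8, 8)

ref = AttentionConv(16, 32, kernel_size=3, padding=1, groups=4)
mine = SASAConv2d(16, 32, kernel_size=3, heads=4)

print(ref(x).shape)   # torch.Size([2, 32, 8, 8])
print(mine(x).shape)  # torch.Size([2, 32, 8, 8])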

Assuming you would like to compare both models, you could create a mapping between their state_dicts and load the parameters of one into the other. Once this is done, perform a forward pass with the same input data and compare the outputs. I don’t see any dropout layers, but you might want to call model.eval() in case you are using layers that change their behavior between training and evaluation runs.
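
Something along these lines could work as a starting point (a rough sketch based on the modules you posted; the parameter mapping and the reshapes of the relative embeddings are my guess at the correspondence, so please double check them):

# copy the parameters of the reference module into yours, then compare outputs
import torch

ref = AttentionConv(16, 32, kernel_size=3, padding=1, groups=4)
mine = SASAConv2d(16, 32, kernel_size=3, heads=4)

with torch.no_grad():
    # the 1x1 conv weights have identical shapes and can be copied directly
    mine.q_conv[0].weight.copy_(ref.query_conv.weight)
    mine.k_conv[0].weight.copy_(ref.key_conv.weight)
    mine.v_conv[0].weight.copy_(ref.value_conv.weight)
    # the relative embeddings only differ by their broadcasting layout (guessed mapping)
    rel = mine.k_conv[2]
    rel.height_mat.copy_(ref.rel_h.reshape(1, 16, 3, 1, 1))
    rel.width_mat.copy_(ref.rel_w.reshape(1, 16, 1, 3, 1))

ref.eval(); mine.eval()
x = torch.randn(2, 16, 8, 8)
print(torch.allclose(ref(x), mine(x), atol=1e-6))  # True only if the two forwards are equivalent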
If the forward passes match, you could then compare the gradients of all parameters (using the same state_dict mapping) after calling backward().
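
Continuing the sketch, the gradient comparison could look like this:

# compare gradients after a backward pass with the same input
x = torch.randn(2, 16, 8, 8)
ref(x).sum().backward()
mine(x).sum().backward()

# the 1x1 conv weights map one-to-one
for p_ref, p_mine in [(ref.query_conv.weight, mine.q_conv[0].weight),
                      (ref.key_conv.weight,   mine.k_conv[0].weight),
                      (ref.value_conv.weight, mine.v_conv[0].weight)]:
    print(torch.allclose(p_ref.grad, p_mine.grad, atol=1e-6))

# the relative embeddings need the same reshape as before
print(torch.allclose(ref.rel_h.grad.reshape(1, 16, 3, 1, 1),
                     mine.k_conv[2].height_mat.grad, atol=1e-6))
print(torch.allclose(ref.rel_w.grad.reshape(1, 16, 1, 3, 1),
                     mine.k_conv[2].width_mat.grad, atol=1e-6))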
