Hello dear forum,
I am trying to understand whether two implementations are equivalent. In particular, the reference implementation extracts the local windows with `Tensor.unfold` (the `k_out.unfold(...)` / `v_out.unfold(...)` calls), while I decided to use `nn.Unfold` in mine. I am also unsure whether the softmax is computed the same way in both cases.
The reference implementation is from github/leaderj1001.
Both implementations follow the paper *Stand-Alone Self-Attention in Vision Models*.
# original
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
class AttentionConv(nn.Module):
    """Local self-attention layer (reference implementation, leaderj1001).

    Every output position attends over a ``kernel_size x kernel_size`` window
    of the zero-padded input. Keys are offset by learned relative embeddings.
    The first ``out_channels // 2`` channels get a per-window-row embedding
    (``rel_h``). The remaining channels get a per-window-column embedding
    (``rel_w``).

    The attention logits are the *element-wise* product ``q * k``. The
    softmax runs over the window positions separately for every channel, and
    channels are never summed before the softmax. Because of that, ``groups``
    only changes an intermediate reshape and does not affect the result.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. Must be even (half for
            ``rel_h``, half for ``rel_w``) and divisible by ``groups``.
        kernel_size: Side length of the square attention window.
        stride: Step between consecutive windows.
        padding: Zero padding added on every side before windows are extracted.
        groups: Number of channel groups used in the intermediate reshape.
        bias: Whether the 1x1 query/key/value projections use a bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=False):
        super(AttentionConv, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups

        assert self.out_channels % self.groups == 0, "out_channels should be divided by groups. (example: out_channels: 40, groups: 4)"
        # forward() splits the keys into two halves of out_channels // 2 channels;
        # an odd channel count would produce a third chunk and break the unpacking.
        assert self.out_channels % 2 == 0, "out_channels should be even (half for rel_h, half for rel_w)"

        # These broadcast against unfolded keys of shape (B, C/2, h_out, w_out, k, k):
        # rel_h varies along the window-row axis, rel_w along the window-column axis.
        self.rel_h = nn.Parameter(torch.randn(out_channels // 2, 1, 1, kernel_size, 1), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn(out_channels // 2, 1, 1, 1, kernel_size), requires_grad=True)

        self.key_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
        self.value_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

        self.reset_parameters()

    def forward(self, x):
        """Apply local self-attention.

        Args:
            x: Tensor of shape (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, out_channels, h_out, w_out), where
            ``h_out = (height + 2 * padding - kernel_size) // stride + 1``
            (and the same formula applies to the width).
        """
        batch = x.size(0)
        k = self.kernel_size
        s = self.stride
        padded_x = F.pad(x, [self.padding, self.padding, self.padding, self.padding])

        k_out = self.key_conv(padded_x)
        v_out = self.value_conv(padded_x)
        # (B, C, Hp, Wp) -> (B, C, h_out, w_out, k, k): one k x k window per output position.
        k_out = k_out.unfold(2, k, s).unfold(3, k, s)
        v_out = v_out.unfold(2, k, s).unfold(3, k, s)
        h_out, w_out = k_out.size(2), k_out.size(3)

        # Take each window's query at the window centre (in padded coordinates),
        # so q lines up with the key/value grid for every stride and padding.
        # A 1x1 conv is pointwise, so selecting pixels before the conv is the
        # same as selecting them afterwards. For stride=1 and
        # 2 * padding == kernel_size - 1 this selects exactly the unpadded input.
        centre = (k - 1) // 2
        q_in = padded_x[:, :, centre::s, centre::s][:, :, :h_out, :w_out]
        q_out = self.query_conv(q_in)

        k_out_h, k_out_w = k_out.split(self.out_channels // 2, dim=1)
        k_out = torch.cat((k_out_h + self.rel_h, k_out_w + self.rel_w), dim=1)

        channels_per_group = self.out_channels // self.groups
        k_out = k_out.contiguous().view(batch, self.groups, channels_per_group, h_out, w_out, -1)
        v_out = v_out.contiguous().view(batch, self.groups, channels_per_group, h_out, w_out, -1)
        q_out = q_out.reshape(batch, self.groups, channels_per_group, h_out, w_out, 1)

        # Element-wise logits; the softmax is over the k*k window positions, per channel.
        out = q_out * k_out
        out = F.softmax(out, dim=-1)
        out = torch.einsum('bnchwk,bnchwk -> bnchw', out, v_out).reshape(batch, -1, h_out, w_out)
        return out

    def reset_parameters(self):
        """Kaiming-init the 1x1 projections; re-draw the relative embeddings from N(0, 1)."""
        init.kaiming_normal_(self.key_conv.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.value_conv.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.query_conv.weight, mode='fan_out', nonlinearity='relu')

        init.normal_(self.rel_h, 0, 1)
        init.normal_(self.rel_w, 0, 1)
# my implementation
import torch
from torch import nn
from torch.nn import functional as F
from math import sqrt
from einops import rearrange
from einops.layers.torch import Rearrange
from opt_einsum import contract as einsum
class RelativeEmbeddings2d(nn.Module):
    """Add learned 2-D relative position embeddings to unfolded patches.

    The input is laid out as ``(N, C * K * K, L)``, the layout produced by
    ``nn.Unfold``, with ``C = embedding_size`` and ``K = extent``. The first
    ``C // 2`` channels receive ``height_mat``, which has one value per window
    row. The remaining channels receive ``width_mat``, which has one value per
    window column. The output has the same layout as the input.

    Args:
        extent: Side length ``K`` of the square window.
        embedding_size: Number of channels ``C``; must be even.
    """

    def __init__(self, extent, embedding_size):
        super(RelativeEmbeddings2d, self).__init__()
        assert isinstance(extent, int), 'RelativeEmbeddings2d requires integer extent'
        # forward() splits the channels into two halves; an odd count would
        # yield three chunks and break the two-way unpacking.
        assert embedding_size % 2 == 0, 'RelativeEmbeddings2d requires an even embedding_size'
        self.extent = extent
        self.embedding_size = embedding_size
        # Axes are (N, C/2, K1 = window row, K2 = window column, L).
        self.width_mat = nn.Parameter(torch.randn((1, embedding_size // 2, 1, extent, 1)), requires_grad=True)
        self.height_mat = nn.Parameter(torch.randn((1, embedding_size // 2, extent, 1, 1)), requires_grad=True)

    def forward(self, x):
        """Return ``x`` with the relative embeddings added, in the same layout."""
        n, _, length = x.shape
        # (N, C*K1*K2, L) -> (N, C, K1, K2, L)
        patches = x.reshape(n, self.embedding_size, self.extent, -1, length)
        x_h, x_w = patches.split(self.embedding_size // 2, dim=1)
        out = torch.cat((x_h + self.height_mat, x_w + self.width_mat), dim=1)
        # (N, C, K1, K2, L) -> (N, C*K1*K2, L)
        return out.reshape(n, -1, length)
class SASAConv2d(nn.Module):
    """Stand-alone Self-attention 2d.

    Multi-head local attention. Every output position attends over a
    ``kernel_size x kernel_size`` window of the input, zero-padded by
    ``(kernel_size - 1) // 2``. The ``out_channels`` channels are split into
    ``heads`` heads of ``D = out_channels // heads`` channels each.

    Within a head, the logits are dot products of a (1, D) query with the
    (D, K*K) key matrix (``bmm``), so they are summed over the head's D
    channels. This differs from ``AttentionConv``, which multiplies q and k
    element-wise and applies the softmax per channel. The two layers are
    therefore not numerically equivalent. The softmax is taken over the K*K
    window positions. Keys get 2-D relative embeddings added before the
    product.

    Output spatial size follows ``(W - F + 2P) / S + 1`` (floored). The query
    grid is ``(H - 1) // stride + 1``. The key/value grid matches it for odd
    ``kernel_size``; for even kernel sizes it may not, and forward() would fail.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels; must be divisible by ``heads``.
        kernel_size: Integer side length of the square attention window.
        heads: Number of attention heads.
        stride: Step between output positions (int or (sh, sw) pair).
    """

    def __init__(self, in_channels, out_channels, kernel_size, heads=4, stride=1):
        super(SASAConv2d, self).__init__()
        assert heads > 0, 'SASAConv2d requires a positive number of heads'
        assert isinstance(kernel_size, int), 'SASAConv2d requires integer kernel_size'
        assert out_channels % heads == 0, 'SASAConv2d requires out_channels divisible by the number of heads'
        padding = (kernel_size - 1) // 2
        self.heads = heads
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        # Needed in forward() to compute the output grid size.
        self.stride = stride

        # Queries: one pixel per output position -> (N*L*M, 1, D)
        self.q_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Unfold(1, 1, 0, stride),
            Rearrange('N (M D) HW -> (N HW M) () D', M=self.heads)
        )
        self.q_conv.apply(init_weights)

        # Keys: K*K window per output position plus relative embeddings -> (N*L*M, D, K*K)
        self.k_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Unfold(kernel_size, 1, padding, stride),
            RelativeEmbeddings2d(extent=kernel_size, embedding_size=out_channels),
            Rearrange('N (M D KK) HW -> (N HW M) D KK', M=self.heads, KK=self.kernel_size ** 2)
        )
        self.k_conv.apply(init_weights)

        # Values: K*K window per output position -> (N*L*M, K*K, D)
        self.v_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Unfold(kernel_size, 1, padding, stride),
            Rearrange('N (M D KK) HW -> (N HW M) KK D', M=self.heads, KK=self.kernel_size ** 2)
        )
        self.v_conv.apply(init_weights)

    def forward(self, x):
        """Apply multi-head local attention.

        Args:
            x: Tensor of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H_out, W_out) with
            ``H_out = (H - 1) // stride + 1`` (and likewise for W_out).
        """
        N, _, H, W = x.size()
        # nn.Unfold(1, 1, 0, stride) in q_conv keeps every stride-th pixel,
        # so the output grid is smaller than the input when stride > 1.
        if isinstance(self.stride, int):
            stride_h = stride_w = self.stride
        else:
            stride_h, stride_w = self.stride
        H_out = (H - 1) // stride_h + 1
        W_out = (W - 1) // stride_w + 1

        q = self.q_conv(x)  # (N*L*M, 1, D)
        k = self.k_conv(x)  # (N*L*M, D, K*K)
        v = self.v_conv(x)  # (N*L*M, K*K, D)
        weights = F.softmax(q.bmm(k), dim=-1)  # (N*L*M, 1, K*K): softmax over window positions
        attn_maps = weights.bmm(v)             # (N*L*M, 1, D)
        return rearrange(attn_maps, '(N H W M) () D -> N (M D) H W', N=N, H=H_out, W=W_out)