Hello everyone,
I am trying to implement CycleFC module from this paper CycleMLP. Here is my code.
class CycleFC(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
sh: int,
sw: int,
):
"""Initialize CycleFC layer
Args:
in_channels (int): Input_channels.
out_channels (int): Ouput_channels.
sh (int): Stepsize along height.
sw (int): Stepsize along width.
"""
super(CycleFC, self).__init__()
self.W_mlp = nn.parameter.Parameter(
torch.randn(size=(in_channels, out_channels)),
requires_grad=True,
)
self.bias = nn.parameter.Parameter(torch.randn(size=(out_channels,)))
self.sh = sh
self.sw = sw
self.in_channels = in_channels
self.out_channels = out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function for CycleFC
Args:
x (torch.Tensor): Input Tensor. Shape : b, c_in, h, w
Returns:
torch.Tensor: Output Tensor. Shape : b, c_out, h, w
"""
b, c_in, h, w = x.shape
output = torch.zeros(size=(b, self.out_channels, h, w)).to(x.device)
for k in range(b):
for i in range(h):
for j in range(w):
output[k, :, i, j] = self.calc_value(
x, w_mlp=self.W_mlp, c_in=c_in, i=i, j=j, h=h, w=w, k=k
)
output[k, :, i, j] += self.bias
return output
def get_offset(self, sh: int, sw: int, c: int) -> Tuple[int, int]:
"""Calculate offset based on stepsize along height and width for a given channel.
Args:
sh (int): Stepsize along height
sw (int): Stepsize along width
c (int): Current channel
Returns:
Tuple[int, int]: Returns offset along height and width for a given channel.
"""
delta_i = c % sh - 1
delta_j = floor(c / sh) % sw - 1
return delta_i, delta_j
def calc_value(
self,
x: torch.Tensor,
w_mlp: torch.Tensor,
c_in: int,
i: int,
j: int,
h: int,
w: int,
k: int,
):
"""Calculate value for CycleFC
Args:
x (torch.Tensor): Input Tensor. Shape : b, c, h, w,
w_mlp (torch.Tensor): Weight matrix.
c_in (int): Input Channel
i (int) : Current Height index
j(int) : Current width index
h(int) : Input height
w(int) : Input weight
k(int) : Batch size
Returns:
torch.Tensor: Output tensor
"""
sum = 0
for c in range(c_in):
delta_i, delta_j = self.get_offset(self.sh, self.sw, c)
i_offset = (i + delta_i) % h
j_offset = (j + delta_j) % w
sum += x[k, c, i_offset, j_offset] * w_mlp[c, :]
return sum
This is a naive implementation of the formula provided in the paper for CycleFC. My training crashes, and it is extremely slow. How can I make this better? I have also attached a screenshot of the formula.