CycleFC implementation crashes during training

Hello everyone,
I am trying to implement CycleFC module from this paper CycleMLP. Here is my code.

class CycleFC(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        sh: int,
        sw: int,
    ):
        """Initialize CycleFC layer

        Args:
            in_channels (int): Input_channels.
            out_channels (int): Ouput_channels.
            sh (int): Stepsize along height.
            sw (int): Stepsize along width.

        """
        super(CycleFC, self).__init__()
        self.W_mlp = nn.parameter.Parameter(
            torch.randn(size=(in_channels, out_channels)),
            requires_grad=True,
        )
        self.bias = nn.parameter.Parameter(torch.randn(size=(out_channels,)))

        self.sh = sh
        self.sw = sw
        self.in_channels = in_channels
        self.out_channels = out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward function for CycleFC

        Args:
            x (torch.Tensor): Input Tensor. Shape : b,  c_in, h, w

        Returns:
            torch.Tensor: Output Tensor. Shape : b, c_out, h, w
        """
        b, c_in, h, w = x.shape
        output = torch.zeros(size=(b, self.out_channels, h, w)).to(x.device)
        for k in range(b):
            for i in range(h):
                for j in range(w):
                    output[k, :, i, j] = self.calc_value(
                        x, w_mlp=self.W_mlp, c_in=c_in, i=i, j=j, h=h, w=w, k=k
                    )
                    output[k, :, i, j] += self.bias
        return output

    def get_offset(self, sh: int, sw: int, c: int) -> Tuple[int, int]:
        """Calculate offset based on stepsize along height and width for a given channel.

        Args:
            sh (int): Stepsize along height
            sw (int): Stepsize along width
            c (int): Current channel

        Returns:
            Tuple[int, int]: Returns offset along height and width for a given channel.
        """
        delta_i = c % sh - 1
        delta_j = floor(c / sh) % sw - 1
        return delta_i, delta_j

    def calc_value(
        self,
        x: torch.Tensor,
        w_mlp: torch.Tensor,
        c_in: int,
        i: int,
        j: int,
        h: int,
        w: int,
        k: int,
    ):
        """Calculate value for CycleFC

        Args:
            x (torch.Tensor): Input Tensor. Shape : b, c, h, w,
            w_mlp (torch.Tensor): Weight matrix.
            c_in (int): Input Channel
            i (int) : Current Height index
            j(int) : Current width index
            h(int) : Input height
            w(int) : Input weight
            k(int) : Batch size
        Returns:
            torch.Tensor: Output tensor
        """
        sum = 0
        for c in range(c_in):
            delta_i, delta_j = self.get_offset(self.sh, self.sw, c)
            i_offset = (i + delta_i) % h
            j_offset = (j + delta_j) % w
            sum += x[k, c, i_offset, j_offset] * w_mlp[c, :]
        return sum

This is a naive implementation of the formula provided in the paper for CycleFC. My training crashes, and it is extremely slow. How can I make this better? I have also attached a screenshot of the formula.