Dimension Mismatch Issue in B-Spline Calculation within KANLinear Class

I’m working on a project where I apply Kolmogorov-Arnold Networks (KAN) to the Mixer architecture (KAN-Mixer). You can check the issue here.

This is the code I tried to run:

X = torch.rand(64, 1, 28, 28) # MNIST (64, 1, 28, 28), CIFAR10 (64, 3, 32, 32)

pe = PatchEmbedding(1, 128, 4) # MNIST (1, 128, 4), CIFAR10 (3, 128, 4) 
t1 = Transformation1()
t2 = Transformation2()
ml = MixerLayer(128, 49, 256, 256) # MNIST (128, 49, 256, 256), CIFAR10 (128, 64, 256, 256)

print(f"Input shape: {X.shape}")
y1 = pe(X)
print(f"Patch Embedding output shape: {y1.shape}")
y2 = t2(y1)
print(f"T2 transformation output shape: {y2.shape}")
y3 = t1(y2)
print(f"T1 transformation output shape: {y3.shape}")
y4 = t1(y3)
print(f"T1 transformation output shape: {y4.shape}")
y5 = ml(y4)    # Never reached: the traceback shows the error is raised earlier, when MixerLayer is constructed
print(f"Mixer Layer output shape: {y5.shape}")

This is the error message:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 6
      4 t1 = Transformation1()
      5 t2 = Transformation2()
----> 6 ml = MixerLayer(128, 49, 256, 256) # MNIST (128, 49, 256, 256), CIFAR10 (128, 64, 256, 256)
      8 print(f"Input shape: {X.shape}")
      9 y1 = pe(X)

File ~/model.py:265, in MixerLayer.__init__(self, embedding_dim, num_patch, token_intermediate_dim, channel_intermediate_dim, dropout)
    259 def __init__(self, embedding_dim, num_patch, token_intermediate_dim, channel_intermediate_dim, dropout=0.):
    260     super().__init__()
    262     self.token_mixer = nn.Sequential(
    263         nn.LayerNorm(embedding_dim),
    264         Transformation1(),
--> 265         KAN(num_patch, token_intermediate_dim, dropout),
    266         Transformation1()
    267     )
    269     self.channel_mixer = nn.Sequential(
    270         nn.LayerNorm(embedding_dim),
    271         KAN(embedding_dim, channel_intermediate_dim, dropout),
    272     )

File ~/model.py:223, in KAN.__init__(self, dim, intermediate_dim, dropout)
    220 def __init__(self, dim, intermediate_dim, dropout = 0.):
    221     super().__init__()
    222     self.kan = nn.Sequential(
--> 223         KANLinear(dim, intermediate_dim),
    224         KANLinear(intermediate_dim, dim),
    225     )

File ~/model.py:60, in KANLinear.__init__(self, in_features, out_features, grid_size, spline_order)
     56 self.spline_scaler = nn.Parameter(torch.Tensor(out_features, in_features))
     58 self.base_activation = nn.SiLU()
---> 60 self.reset_parameters()

File ~/model.py:73, in KANLinear.reset_parameters(self)
     71 # Compute the spline weight coefficients from the random noise
     72 grid_points = self.grid.T[self.spline_order : -self.spline_order]
---> 73 spline_coefficients = self.curve2coeff(grid_points, random_noise)
     75 # Copy the computed coefficients to the spline weight tensor
     76 self.spline_weight.data.copy_(spline_coefficients)

File ~/model.py:177, in KANLinear.curve2coeff(self, input_tensor, output_tensor)
    166 """
    167 Compute the coefficients of the curve that interpolates the given points.
    168 
   (...)
    174     torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
    175 """
    176 # Compute the B-spline bases for the input tensor
--> 177 b_splines_bases = self.b_splines(
    178     input_tensor
    179 )  # (batch_size, input_dim, grid_size + spline_order)
    181 # Transpose the B-spline bases and output tensor for matrix multiplication
    182 transposed_bases = b_splines_bases.transpose(
    183     0, 1
    184 )  # (input_dim, batch_size, grid_size + spline_order)

File ~/model.py:149, in KANLinear.b_splines(self, x)
    147 test = (input_tensor_expanded - expanded_grid[:, :, : -order - 1]) / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
    148 print(f"left_term: {test.shape}")
--> 149 test2 = test * bases[:, :, :-1]
    151 left_term = (
    152     (input_tensor_expanded - expanded_grid[:, :, : -order - 1])
    153     / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
    154 ) * bases[:, :, :-1]
    156 right_term = (
    157     (expanded_grid[:, :, order + 1 :] - input_tensor_expanded)
    158     / (expanded_grid[:, :, order + 1 :] - expanded_grid[:, :, 1:-order])
    159 ) * bases[:, :, 1:]

RuntimeError: The size of tensor a (10) must match the size of tensor b (11) at non-singleton dimension 2

Where the problem starts

class KANLinear(nn.Module):
    # ...
    def b_splines(self, x: torch.Tensor):
        # ...

        # Compute the B-spline bases recursively
        for order in range(1, self.spline_order + 1):
            input_tensor_expanded = input_tensor_expanded.repeat(1, 1, expanded_grid[:, :, : -order - 1].shape[-1])
            test = (input_tensor_expanded - expanded_grid[:, :, : -order - 1]) / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
            print(f"left_term: {test.shape}")
            test2 = test * bases[:, :, :-1] # <--- Problem starts here

            left_term = (
                (input_tensor_expanded - expanded_grid[:, :, : -order - 1])
                / (expanded_grid[:, :, order:-1] - expanded_grid[:, :, : -order - 1])
            ) * bases[:, :, :-1]

            # ...

        return bases.contiguous()

    # ...

Tensor shapes

  • expanded_grid and bases: torch.Size([6, 49, 12])
  • input_tensor_expanded: torch.Size([6, 49])

After computing the B-spline bases

  • input_tensor_expanded, expanded_grid, and test: torch.Size([6, 49, 10])
  • bases: torch.Size([6, 49, 12])
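The mismatch can be reproduced with dummy tensors of exactly these shapes (random values, only the shapes matter). This sketch assumes the failing step is the first pass of the loop (order = 1). It slices the tensors the same way the loop does and shows that degree-0 bases with one entry per knot interval (12 knots, so 11 intervals) would line up:

```python
import torch

# Shapes from the post: 6 grid points, 49 input features, 12 knots per feature.
expanded_grid = torch.rand(6, 49, 12)
bases = torch.rand(6, 49, 12)  # same last dimension as the grid, as in the post

order = 1
# Same slice that produces the shape of `test` in the post: 12 - 2 = 10 entries.
grid_left = expanded_grid[:, :, : -order - 1]
print(grid_left.shape, bases[:, :, :-1].shape)
# torch.Size([6, 49, 10]) torch.Size([6, 49, 11])  -> the 10 vs 11 in the error

# Degree-0 bases have one entry per knot *interval*: 12 knots -> 11 intervals.
bases_from_intervals = torch.rand(6, 49, expanded_grid.shape[-1] - 1)
print(bases_from_intervals[:, :, :-1].shape)
# torch.Size([6, 49, 10])  -> now matches grid_left
```

In other words, the starting bases tensor should have one fewer entry along the last dimension than the grid, not the same number.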

Problem

The error comes from the element-wise product test * bases[:, :, :-1]. test has shape torch.Size([6, 49, 10]), while bases has shape torch.Size([6, 49, 12]), so bases[:, :, :-1] has shape torch.Size([6, 49, 11]) and the last dimensions (10 vs 11) do not match. I need to figure out how to adjust the shape of bases so that it lines up with input_tensor_expanded and expanded_grid and the error goes away.
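For comparison, here is a standalone sketch of the Cox-de Boor recursion with shapes that stay consistent. It follows the slicing pattern visible in the traceback (sliced grid terms multiplied by bases[:, :, :-1] and bases[:, :, 1:]), but the function b_splines(x, grid, spline_order) is my own hypothetical helper, not the KANLinear method itself. The usage part assumes grid_size = 5 and spline_order = 3 on a uniform grid, which reproduces the 12-knot, 6-point shapes above. The key detail is that the degree-0 bases compare x against grid[:, :-1] and grid[:, 1:], so they start with num_knots - 1 entries and every later slice lines up:

```python
import torch


def b_splines(x: torch.Tensor, grid: torch.Tensor, spline_order: int) -> torch.Tensor:
    """Hypothetical helper: evaluate B-spline bases with the Cox-de Boor recursion.

    x:    (batch, in_features)
    grid: (in_features, num_knots), knots sorted along the last dimension
    returns (batch, in_features, num_knots - 1 - spline_order)
    """
    assert x.dim() == 2 and x.size(1) == grid.size(0)
    x = x.unsqueeze(-1)  # (batch, in_features, 1), broadcasts against the grid

    # Degree 0: one indicator per knot interval -> num_knots - 1 entries.
    bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)

    # Each pass raises the degree by one and drops one entry from the last dim.
    for k in range(1, spline_order + 1):
        left = (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)])
        right = (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:-k])
        bases = left * bases[:, :, :-1] + right * bases[:, :, 1:]

    return bases.contiguous()


# Assumed settings: grid_size = 5, spline_order = 3, 49 input features.
grid_size, spline_order, in_features = 5, 3, 49
h = 2.0 / grid_size
# Uniform knots on [-1, 1], extended by spline_order knots on each side (12 knots).
grid = (torch.arange(-spline_order, grid_size + spline_order + 1) * h - 1.0)
grid = grid.expand(in_features, -1).contiguous()  # (49, 12)
x = grid.T[spline_order:-spline_order]  # (6, 49), same shape as grid_points in the trace
print(b_splines(x, grid, spline_order).shape)  # torch.Size([6, 49, 8])
```

With this layout the last dimension shrinks 11 → 10 → 9 → 8, ending at grid_size + spline_order, which is the size the curve2coeff docstring in the traceback expects.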

I have finally resolved the tensor dimension error! Check out the PR here

Problem

After the transformations, the tensor passed to KANLinear is 3D, which is not compatible with the original KANLinear format: the KANLinear class expects a 2D input of shape (batch, in_features). I want to modify it to accept a 3D tensor.

Solution

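I don’t have to change the B-spline code in KAN at all. I only need to modify the forward method of KANLinear so that it merges all leading dimensions of the input into a single batch dimension (for example, (batch, tokens, in_features) becomes (batch * tokens, in_features)), runs the existing 2D computation, and then reshapes the output back to the original leading dimensions, with the last dimension changed from in_features to out_features.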

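import torch
import torch.nn as nn
import torch.nn.functional as F


class KANLinear(nn.Module):
    # ... (keep the existing code here)

    def forward(self, x: torch.Tensor):
        # Save the original shape, e.g. (batch, tokens, in_features)
        original_shape = x.shape

        # Merge all leading dimensions into one batch dimension:
        # (..., in_features) -> (N, in_features), the 2D layout the original code expects
        x = x.contiguous().view(-1, self.in_features)

        # SiLU base branch and B-spline branch, both of shape (N, out_features)
        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.spline_weight.view(self.out_features, -1),
        )

        # Sum the two branches
        output = base_output + spline_output

        # Restore the leading dimensions; only the last one changes,
        # from in_features to out_features
        output = output.view(*original_shape[:-1], self.out_features)

        return output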
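To check the flatten/restore idea in isolation, here is a small sketch that uses a hypothetical 2D-only stand-in layer (TwoDOnly, built on nn.Linear) instead of the real KANLinear, which isn't reproduced in full here. The helper apply_on_last_dim is also hypothetical; it mirrors the forward method above, using reshape, which gives the same result as contiguous().view(). The input shape (64, 128, 49) is just an example whose last dimension matches num_patch = 49:

```python
import torch
import torch.nn as nn


class TwoDOnly(nn.Module):
    """Hypothetical stand-in for the original KANLinear: accepts only (N, in_features)."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 2 and x.size(1) == self.in_features
        return self.linear(x)


def apply_on_last_dim(layer: TwoDOnly, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper with the same flatten/restore pattern as forward() above."""
    original_shape = x.shape
    flat = x.reshape(-1, layer.in_features)  # (..., in) -> (N, in)
    out = layer(flat)                        # (N, out)
    return out.reshape(*original_shape[:-1], layer.out_features)  # (..., out)


layer = TwoDOnly(49, 256)
x = torch.rand(64, 128, 49)  # example 3D input whose last dimension is in_features
y = apply_on_last_dim(layer, x)
print(y.shape)  # torch.Size([64, 128, 256])
```

Because the layer treats each row along the last dimension independently, merging the leading dimensions does not mix values between samples or patches; the result for any single row is the same as passing that row through the layer on its own.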